Meta-level operations on a Spark data frame
Use case
Sometimes you want to apply an operation to a subset of the columns of a data frame, like trimming all string columns or converting all empty strings into null values.
This is typically done in the extract phase when data flows into your ETL pipeline.
The question is now: how can this be done in a generic way, that is, without writing code for each record type (you most likely have lots of different record types representing your entities)? And since we are using Scala, how can it be accomplished in an immutable way?
Sample data set
These data structures are used in the examples:
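```scala
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.types.StringType

case class Rec(name: String, city: String, age: Int)

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._ // provides the Encoder[Rec] required by createDataset

val xs =
  List(
    Rec("Alice", "Amsterdam", 20),
    Rec("Bob", "Berlin", 30),
    Rec("Charlie", "Chicago", 40)
  )
val ds: Dataset[Rec] = spark.createDataset(xs)

// Collect the names of all columns whose type is String
val schema = ds.schema
val columnNames = schema
  .fields
  .collect { case f if f.dataType == StringType => f.name }
  .toList
```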
The schema (metadata) of the data set is accessible through the schema property of a Spark Dataset. The snippet above collects the names of all string columns into the columnNames value.
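Since this selection depends only on the schema, it can be pulled out into a reusable function that works for any record type. The following is a minimal sketch; columnsOfType is a hypothetical helper name, and it only inspects top-level fields, so columns nested inside structs are not found.

```scala
import org.apache.spark.sql.types._

object SchemaHelpers {
  // Hypothetical helper: names of all top-level columns with the given data type,
  // in schema order. Nested struct fields are not inspected.
  def columnsOfType(schema: StructType, dataType: DataType): List[String] =
    schema.fields.collect { case f if f.dataType == dataType => f.name }.toList
}
```

With the sample data set, SchemaHelpers.columnsOfType(ds.schema, StringType) yields the same list as columnNames above. Because it works purely on a StructType, the helper can also be used on a schema that was built by hand, without any data.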
Option 1: Tail recursion
The 1st option is to use tail recursion:
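```scala
import scala.annotation.tailrec
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.upper

@tailrec
def toUpperCase(df: DataFrame,
                columnNames: List[String]): DataFrame = {
  columnNames match {
    case Nil =>
      df // no columns left: the accumulated data frame is the result
    case h :: t =>
      // replace column h with its upper-cased version, then continue with the rest
      toUpperCase(df.withColumn(h, upper(df(h))), t)
  }
}

val resultDf = toUpperCase(ds.toDF, columnNames)
resultDf.show()
resultDf.as[Rec].show() // the schema is unchanged, so the result converts back to Dataset[Rec]
```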
This snippet passes the list of string column names into the tail-recursive function toUpperCase().
toUpperCase() performs a head-tail decomposition of the column list: the list is split into h (the head element) and t (the tail, i.e. the list of all remaining elements). It applies the upper() function to the head column and then calls itself recursively with the updated data frame and the remaining columns (the tail).
The @tailrec annotation asks the compiler to verify that tail call optimization can be applied; if it cannot, compilation fails. When the optimization is applied, the compiler effectively turns the recursion into a normal loop, which is faster and avoids stack overflow problems.
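The head-tail decomposition and the loop it turns into can be seen without a Spark session. This minimal pure-Scala sketch uses a plain Map[String, String] as a hypothetical stand-in for a single record of a data frame: upperAll has the same shape as toUpperCase, and upperAllLoop is roughly the loop that tail call optimization produces from it.

```scala
import scala.annotation.tailrec

object TailRecDemo {
  // Hypothetical stand-in for a data frame: column name -> value of one record
  type Record = Map[String, String]

  // Same shape as toUpperCase: handle the head column, then recurse on the tail
  @tailrec
  def upperAll(record: Record, columnNames: List[String]): Record =
    columnNames match {
      case Nil    => record
      case h :: t => upperAll(record.updated(h, record(h).toUpperCase), t)
    }

  // Roughly what the compiler makes of the tail call: a plain while loop
  def upperAllLoop(record: Record, columnNames: List[String]): Record = {
    var acc  = record
    var rest = columnNames
    while (rest.nonEmpty) {
      acc = acc.updated(rest.head, acc(rest.head).toUpperCase)
      rest = rest.tail
    }
    acc
  }

  def main(args: Array[String]): Unit = {
    val alice = Map("name" -> "Alice", "city" -> "Amsterdam")
    println(upperAll(alice, List("name", "city")))
    println(upperAllLoop(alice, List("name", "city")))
  }
}
```

Both functions return a new map and never modify their input, which mirrors how withColumn returns a new data frame instead of changing the existing one.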
Option 2: Using a fold
As an alternative to tail recursion, you could use a fold, which is provided by the Scala collection library:
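```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.upper

// The data frame is the initial value; each step upper-cases one more column
val resultDf: DataFrame = columnNames.foldLeft(ds.toDF) {
  case (df, columnName) =>
    df.withColumn(columnName, upper(df(columnName)))
}
resultDf.show()
resultDf.as[Rec].show()
```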
This snippet folds over the column names (this is like a loop over each column name).
It uses the data frame as the initial value (the 1st argument to the foldLeft() function).
The data frame also serves as the accumulator during the fold: the result aggregated so far appears as the 1st value in the pattern match (see the case line).
The 2nd value in the pattern match is the current column name.
Each case returns the new aggregation result, in this case the updated data frame.
foldLeft() is used here so that the columns are processed in their original order (left to right), just like in the tail-recursive version.
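The fold is not tied to upper(). Passing the per-column operation in as a Column => Column function covers the use cases from the beginning, such as trimming strings or turning empty strings into null. This is a minimal sketch assuming the Spark SQL Scala API; transformColumns and trimToNull are hypothetical helper names chosen for illustration, not part of Spark.

```scala
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.functions.{lit, trim, when}
import org.apache.spark.sql.types.StringType

object ColumnTransforms {
  // Hypothetical helper: applies f to every listed column, left to right,
  // returning a new DataFrame each step (the input is never modified)
  def transformColumns(df: DataFrame, columnNames: Seq[String])(f: Column => Column): DataFrame =
    columnNames.foldLeft(df) { case (acc, columnName) =>
      acc.withColumn(columnName, f(acc(columnName)))
    }

  // Hypothetical transformation: trims the value and turns empty
  // (or whitespace-only) strings into null
  def trimToNull(c: Column): Column =
    when(trim(c) === "", lit(null).cast(StringType)).otherwise(trim(c))
}
```

Applied to the sample data, this would be ColumnTransforms.transformColumns(ds.toDF, columnNames)(ColumnTransforms.trimToNull). Treating whitespace-only strings as empty is a design choice of this sketch; drop the inner trim() from the condition if only truly empty strings should become null.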
Conclusion
The version using the fold is shorter, and in my opinion it conveys the intention a little better: a fold makes concepts such as the initial value and the accumulated value explicit. But this might just be a matter of taste.
Either way, these are two possibilities to process data frames generically, and both do it in an immutable way.
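Both options above call withColumn once per column. The Spark API documentation for Dataset.withColumn notes that each call introduces a projection, so calling it in a loop over many columns can produce large query plans and even a StackOverflowException, and it recommends a single select with all columns instead. The following is a minimal sketch of that variant; toUpperCaseSelect is a hypothetical helper name. It remains generic and immutable.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, upper}

object SelectVariant {
  // Hypothetical helper: one select instead of one withColumn per column.
  // Listed columns are upper-cased (keeping their name); all others pass through
  // unchanged, and the original column order is kept.
  def toUpperCaseSelect(df: DataFrame, columnNames: Set[String]): DataFrame =
    df.select(df.columns.toIndexedSeq.map { c =>
      if (columnNames.contains(c)) upper(col(c)).as(c) else col(c)
    }: _*)
}
```

With the sample data, this would be called as SelectVariant.toUpperCaseSelect(ds.toDF, columnNames.toSet).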