With the introduction of the pivot function in Spark 1.6.0, I thought I'd give implementing a simple version of melt a go. It isn't yet as flexible as the reshape2 package in R, but it already does a pretty good job of following the same approach that reshape2 takes.

The essential idea behind the code is to use flatMap on a DataFrame to emit multiple rows (observations) for each row in the data frame, remapping the values as we go.

As a simple example, consider the following:

from pyspark.sql import Row

def simple_melt(row):
    # usage: `sqlContext.createDataFrame(df.flatMap(simple_melt))`
    # emit one output row per value column, keeping the first column as the id
    return [Row(A=row[0], B="1", C=row[1]),
            Row(A=row[0], B="2", C=row[2]),
            Row(A=row[0], B="3", C=row[3])]

This takes each row and melts it, keeping the first column as the id and turning the next three columns into separate (variable, value) rows.
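As a rough sketch of what that looks like end to end (the DataFrame, column names, and values here are made up for illustration):

# assumed input: a DataFrame with an id column followed by three value columns
df = sqlContext.createDataFrame([("a", 10, 20, 30)], ["id", "x1", "x2", "x3"])

# melt it: each input row produces three output rows
df_long = sqlContext.createDataFrame(df.flatMap(simple_melt))

# df_long should now contain, for the row above:
#   Row(A='a', B='1', C=10), Row(A='a', B='2', C=20), Row(A='a', B='3', C=30)
df_long.show()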

We can generalize this by inspecting the field names of each Row object and selecting the id variables we want to keep fixed. This gives the following melt function:

def melt(row, ids, var_name='variable', value_name='value'):
    """takes in a Row object and melts it, keeping the id columns the same;
    
    based on this: http://sinhrks.hatenablog.com/entry/2015/04/29/085353
    """
    for id in ids:
        if id not in row.__fields__:
            raise ValueError(id + " is not found in the list of fields")
    
    row_names = row.__fields__[:]
    variable_fields = set(row_names) - set(ids)
    melted_rows = []
    
    for var in variable_fields:
        curr_var = {}
        for id in ids:
            curr_var[id] = row[row_names.index(id)]  # copy the id columns unchanged
        curr_var[var_name] = var
        curr_var[value_name] = row[row_names.index(var)]
        melted_rows.append(Row(**curr_var))  # one new row per melted column
    return melted_rows

Usage would be:

df_melted = sqlContext.createDataFrame(df.flatMap(lambda row: melt(row, ids=ids)))
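For instance, on a hypothetical wide DataFrame with one id column and two measurement columns (names and values invented for illustration), this might look like:

# hypothetical wide data: one id column plus two measurements
df = sqlContext.createDataFrame(
    [("a", 1.0, 2.0), ("b", 3.0, 4.0)],
    ["name", "height", "weight"])

ids = ["name"]
df_melted = sqlContext.createDataFrame(df.flatMap(lambda row: melt(row, ids=ids)))

# each input row yields one output row per non-id column, so for 'a' we expect
# (in some order) ('a', 'height', 1.0) and ('a', 'weight', 2.0)
df_melted.show()

Note that because the non-id columns are collected into a set, the order in which the melted rows come out isn't guaranteed.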