Spark UDF

pyspark.sql.functions 提供了很多预定义的函数用来对列数据进行处理,有数学函数、agg相关函数、字符串处理函数、列编解码函数、时间相关函数等。自定义函数(user defined function)顾名思义,我们在使用pyspark的过程在进行具体业务分析时难免会遇到内置函数无法满足需求的情况,这时候就需要使用到pyspark中的udf功能。

本节我们将编写简单的UDF,提取日期(如03-12-2019)中的年份(2019)

准备数据集

打开Jupyter,下载数据库集:

!wget https://pingfan.s3.amazonaws.com/files/transactions.csv
!wget https://pingfan.s3.amazonaws.com/files/customers.csv

image-20220607130744100

加载数据到DataFrame:

from pyspark.sql import SparkSession

spark = SparkSession \
       .builder \
       .appName("FirstApp") \
       .getOrCreate()

df = spark.read.csv("customers.csv",header=True, inferSchema = True)
df.show()

image-20220607131435467

使用UDF

现在我们想提取日期(如22-11-2019)中的年份(2019)。如果使用python函数,可以这样写:

def get_year(date):
	  return date.split('-')[2]

在pyspark中,原理是一样的:

from pyspark.sql.functions import udf

extract_year = udf (lambda Date:Date.split('-')[2])
df1 = df.withColumn("year",extract_year(df.Date))
df1.show()

使用UDF,我们成功地将年份数据提取出来:

image-20220607131632687