
Making UDFs in pyspark 6x faster
Highlights

Call UDFs packaged in a jar to cut down the interaction between Python and the JVM. In a simple benchmark, running the UDF computation over a 5.4-billion-row dataset dropped from 3 hours of execution time to 16 minutes. We trade a little extra development time on the UDF itself for as much performance as possible: close to pure-Scala performance at close to pure-Python development cost, balancing performance and development efficiency.

Background

When the data cannot be handled by SQL directly (for example encryption/decryption, or parsing Thrift-encoded binary), we need user-defined functions (UDFs). For development efficiency we usually schedule the jobs with Airflow and write them as pyspark scripts.
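For reference, the pre-optimization approach looks roughly like the sketch below: a plain Python UDF registered with spark.udf.register and called from Spark SQL. The table, column, and decoding logic here are hypothetical stand-ins rather than the author's actual code; the point is only that the function body runs in a Python worker process on the executors.

import base64
from pyspark.sql import SparkSession
from pyspark.sql.types import StringType

spark = SparkSession.builder.appName("python-udf-before").enableHiveSupport().getOrCreate()

# A Python UDF for work that plain SQL cannot express (decoding a binary payload here,
# standing in for decryption / Thrift parsing). It executes inside a Python worker.
def decode_payload(payload):
    if payload is None:
        return None
    return base64.b64decode(payload).decode("utf-8", errors="replace")

spark.udf.register("decode_payload", decode_payload, StringType())

# Hypothetical table and column names, for illustration only.
df = spark.sql("select decode_payload(payload) as payload_text from xxx.events")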

The optimized code

from datetime import datetime
from pyspark.sql import SparkSession
import sys

if __name__ == "__main__":
    # Create the Spark session
    spark = SparkSession.builder.appName("xxx").enableHiveSupport().getOrCreate()

    DT = sys.argv[1]
    HOUR = sys.argv[2]
    F_DT = datetime.strptime(DT, "%Y-%m-%d").strftime("%Y%m%d")
    tmp_tbl_name = f"""temp_xxx_pre_0_{F_DT}_{HOUR}"""

    print('''Registering the Java UDFs''')
    spark.udf.registerJavaFunction("get_area_complex_value", "com.xxx.utf.block.AreaComplexValue")
    spark.udf.registerJavaFunction("most_common_element", "com.xxx.utf.common.MostCommonElement")

    spark.sql('''set hive.exec.dynamic.partition.mode=nonstrict;''')

    exec_sql = '''
    insert overwrite table xxx.xxx  -- your target table
    select a.distinct_id device_id  -- temporary
        ,app_id
        ,install_datetime
        ,os
        ,ip
        ,country
        ,city
        ,uuid
        ,a.distinct_id
        ,event_name
        ,event_timestamp
        ,event_datetime
        ,game_id
        ,game_type
        ,round_id
        ,travel_id
        ,travel_lv
        ,matrix
        ,position
        ,concat('[', concat_ws(',', clean), ']') as clean
        ,block_id
        ,index_id
        ,rec_strategy
        ,rec_strategy_fact
        ,combo_cnt
        ,gain_score
        ,gain_item
        ,block_list
        ,lag_event_timestamp
        ,(event_timestamp - lag_event_timestamp) / 1000 AS time_diff_in_seconds
        ,most_common_element(
            replace(replace(replace(rec_strategy_fact, '[', ''), ']', ''), '"', '')
        ) as rec_strategy_fact_most
        ,lag_matrix
        ,is_clear_screen
        ,is_blast
        ,blast_row_col_cnt
        ,CASE
            WHEN size(clean) > 0 THEN
                CASE
                    WHEN (IF(size(lag_clean_3) > 0, 1, 0) + IF(size(lag_clean_2) > 0, 1, 0) + IF(size(lag_clean_1) > 0, 1, 0)) > 0 THEN TRUE
                    WHEN (IF(size(lag_clean_3) > 0, 1, 0) + IF(size(lag_clean_2) > 0, 1, 0) + IF(size(lag_clean_1) > 0, 1, 0)) = 0
                        AND combo_cnt = '-1'
                        AND (IF(size(lead_clean_3) > 0, 1, 0) + IF(size(lead_clean_2) > 0, 1, 0) + IF(size(lead_clean_1) > 0, 1, 0)) > 0 THEN TRUE
                    ELSE FALSE
                END
            ELSE
                CASE
                    WHEN (IF(size(lag_clean_2) > 0, 1, 0) + IF(size(lag_clean_1) > 0, 1, 0)) = 0 THEN FALSE
                    WHEN (IF(size(lag_clean_2) > 0, 1, 0) + IF(size(lag_clean_1) > 0, 1, 0)) = 2 THEN TRUE
                    WHEN size(lag_clean_1) > 0 AND lag_combo_cnt_1 = -1
                        AND (IF(size(lead_clean_1) > 0, 1, 0) + IF(size(lead_clean_2) > 0, 1, 0)) >= 1 THEN TRUE
                    WHEN size(lag_clean_1) > 0 AND lag_combo_cnt_1 > -1
                        AND (IF(size(lag_clean_3) > 0, 1, 0) + IF(size(lag_clean_4) > 0, 1, 0)) >= 1 THEN TRUE
                    WHEN size(lag_clean_2) > 0 AND lag_combo_cnt_2 = -1 AND size(lead_clean_1) > 0 THEN TRUE
                    WHEN size(lag_clean_2) > 0 AND lag_combo_cnt_2 > -1
                        AND (IF(size(lag_clean_3) > 0, 1, 0) + IF(size(lag_clean_4) > 0, 1, 0) + IF(size(lag_clean_5) > 0, 1, 0)) >= 1 THEN TRUE
                    ELSE FALSE
                END
        END AS is_combo_status
        ,common_block_cnt
        ,step_score
        ,block_index_id
        ,cast(get_area_complex_value(lag_matrix) as int) matrix_complex_value
        ,-0.1 as block_line_percent
        ,-0.1 as corner_outside_percent
        ,-0.1 as corner_inside_percent
        ,sum((event_timestamp - lag_event_timestamp) / 1000) over(partition by game_id, game_type, distinct_id) time_accumulate_seconds  -- double COMMENT 'seconds from the start of this game to the current block'
        ,max(cast(round_id as Integer)) over(partition by game_id, game_type, distinct_id) max_round_id  -- max round number; accurate for data from 2024-12-11 onwards
        ,case when round_id = max(cast(round_id as Integer)) over(partition by game_id, game_type, distinct_id) then true else false end is_final_round  -- boolean COMMENT 'whether this round is the final round'; accurate for data from 2024-12-11 onwards
        ,case when round_id = max(cast(round_id as Integer)) over(partition by game_id, game_type, distinct_id)
              and block_index_id = max(block_index_id) over(partition by game_id, game_type, distinct_id, round_id) then true else false end is_lethal_block  -- boolean COMMENT 'whether this is the lethal block', i.e. the last block placed in the final round; accurate for data from 2024-12-11 onwards
        ,sum(step_score) over(partition by game_id, game_type, distinct_id) - step_score lag_accumulate_score  -- double COMMENT 'accumulated score before this block'
        ,sum(step_score) over(partition by game_id, game_type, distinct_id) accumulate_score  -- double COMMENT 'accumulated score after this block'
        ,cast((event_timestamp - last_click_time) / 1000 as int) as time_action_in_seconds  -- time taken by the block-drop action
        ,(event_timestamp - lag_event_timestamp) / 1000 - (event_timestamp - last_click_time) / 1000 as time_think_in_seconds  -- thinking time before the drop
        ,0 as last_click_time
        ,cast(`gain_score_per_done` as Integer) as gain_score_per_done
        ,cast(`is_clean_screen` as Integer) as is_clean_screen
        ,cast(`weight` as float) as weight
        ,cast(`put_rate` as float) as put_rate
        ,0 as userwaynum
        ,clean_times
        ,clean_cnt
        ,sum(clean_times) over(partition by game_id, game_type, distinct_id) as accumulate_clean_times
        ,sum(clean_cnt) over(partition by game_id, game_type, distinct_id) as accumulate_clean_cnt
        ,app_version
        ,ram
        ,disk
        ,cast(get_area_complex_value(matrix) as int) cur_matrix_complex_value
        ,block_down_color
        ,design_position
        ,1 as is_sdk_sample
        ,network_type
        ,session_id
        ,block_shape_list
        ,block_shape
        ,design_postion_upleft
        ,fps
        ,-1 as fact_line
        ,case when block_id in (1) then 4
              when block_id in (2, 3) then 6
              when block_id in (4, 5, 6, 9, 15, 27, 28, 37, 38) then 8
              when block_id in (7, 8, 10, 14, 16, 17, 18, 19, 20, 25, 26, 29, 30, 31, 32, 33, 34, 35, 36, 42) then 10
              when block_id in (11, 12, 13, 21, 22, 23, 24, 39, 40, 41) then 12
         end as total_line
        ,dt
    from xxx.xxx a  -- your source table
    where dt = '{DT}'
        and event_name = 'game_touchend_block_done'
        and (event_timestamp - lag_event_timestamp) > 0;
    '''.format(DT=DT, tmp_tbl_name=tmp_tbl_name)

    print(exec_sql)
    spark.sql(exec_sql)

    # Stop the Spark session
    spark.stop()

How it works under the hood

As the pyspark architecture diagram shows, pyspark does not re-implement a compute engine in Python the way dpark did. It still reuses the Scala/JVM execution engine underneath and only uses py4j to build a bridge over which the Python process and the JVM call each other.

driver: the pyspark script and the SparkContext's JVM call each other through py4j;
executor: because the driver has already wrapped the Spark operators and compiled the execution plan into bytecode, the Python process normally does not need to participate at all. Only when a UDF (including UDFs written as lambda expressions) has to run is the work delegated to a Python process (the BatchEvalPython step in the DAG), and in that case the JVM and the Python process communicate over a socket.

Because the pyspark job above uses a plain Python UDF, a BatchEvalPython step shows up in its DAG.
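A quick way to see this is to compare physical plans. The minimal sketch below (made-up column names, not the author's job) registers a throwaway Python UDF and then the Java UDF from this article; the Python version puts a BatchEvalPython node in the plan, while the Java version, assuming the jar is on the classpath, does not.

from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.appName("plan-check").getOrCreate()
df = spark.createDataFrame([("1100",), ("0011",)], ["matrix"])

# Python UDF: the physical plan contains a BatchEvalPython node.
spark.udf.register("py_complexity", lambda m: len(m), IntegerType())
df.selectExpr("py_complexity(matrix)").explain()

# Java UDF from the jar (class name taken from the optimized code above):
# after registration, BatchEvalPython disappears from the plan.
spark.udf.registerJavaFunction("get_area_complex_value",
                               "com.xxx.utf.block.AreaComplexValue")
df.selectExpr("get_area_complex_value(matrix)").explain()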

The BatchEvalPython step

Reference source: spark/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala at master · apache/spark · GitHub. The implementation is as straightforward as the name suggests: it simply takes 100 rows at a time and hands them to the Python process to evaluate:

// Around line 58:
// Input iterator to Python: input rows are grouped so we send them in batches to Python.
// For each row, add it to the queue.
val inputIterator = iter.map { row =>
  if (needConversion) {
    EvaluatePython.toJava(row, schema)
  } else {
    // fast path for these types that does not need conversion in Python
    val fields = new Array[Any](row.numFields)
    var i = 0
    while (i < row.numFields) {
      val dt = dataTypes(i)
      fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
      i += 1
    }
    fields
  }
}.grouped(100).map(x => pickle.dumps(x.toArray))

Since the time bottleneck of our jobs is usually the computation on the executors rather than on the driver, we should focus on reducing how often the executors have to call into Python code.
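One practical consequence (a hypothetical sketch, not from the article): whenever a built-in Spark SQL function can express the logic, prefer it over a Python UDF, because the built-in version runs entirely inside the executor JVM and never crosses into a Python worker.

from pyspark.sql import SparkSession, functions as F
from pyspark.sql.types import StringType

spark = SparkSession.builder.appName("builtin-vs-udf").getOrCreate()
df = spark.createDataFrame([('{"country":"US"}',)], ["payload"])

# Slow path: a Python UDF forces a BatchEvalPython exchange for every batch of rows.
extract_country = F.udf(lambda s: s and s.split('"country":"')[1].split('"')[0], StringType())
slow = df.select(extract_country("payload").alias("country"))

# Fast path: the equivalent built-in function stays inside the JVM.
fast = df.select(F.get_json_object("payload", "$.country").alias("country"))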

Reference source: spark/python/pyspark/java_gateway.py at master · apache/spark · GitHub

# Around line 135:
# Import the classes used by PySpark
java_import(gateway.jvm, "org.apache.spark.SparkConf")
java_import(gateway.jvm, "org.apache.spark.api.java.*")
java_import(gateway.jvm, "org.apache.spark.api.python.*")
java_import(gateway.jvm, "org.apache.spark.ml.python.*")
java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
# TODO(davies): move into sql
java_import(gateway.jvm, "org.apache.spark.sql.*")
java_import(gateway.jvm, "org.apache.spark.sql.api.python.*")
java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
java_import(gateway.jvm, "scala.Tuple2")

pyspark wraps many common operations so that they run inside the JVM, but our UDF is obviously not among them. So a very natural idea is to push our UDF into the JVM as well.
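Those java_import calls are what make JVM classes reachable from the Python driver through the py4j gateway. The sketch below only illustrates that bridge; the _jvm handle is an internal, undocumented attribute and nothing in the article depends on calling it directly.

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("py4j-bridge").getOrCreate()
jvm = spark.sparkContext._jvm  # py4j view of the driver-side JVM (internal API)

# Each of these lines is a py4j call that crosses the Python/JVM boundary on the driver.
print(jvm.java.lang.System.currentTimeMillis())
print(jvm.org.apache.spark.SparkConf().get("spark.app.name", "unset"))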

Rewriting the Python UDF in Java

github.com/sunlongjiang/adx/blob/master/adx_platform_common/src/main/java/com/hungrystudio/utf/block/AreaComplexValue.java

Then reference the jar in the job with --jars:

"--jars", "s3://hungry-studio-data-warehouse/user/sunlj/java_udf/adx_platform_common-4.0.0.jar",

Running the rewritten job, the plan has two fewer transforms than before: the BatchEvalPython step is gone, and one WholeStageCodegen stage disappears as well.

Before the optimization the job took 3 hours to run; after the optimization it finishes in 16 minutes. The effect is dramatic!

So in pyspark, prefer Spark operators and Spark SQL wherever possible, and keep UDFs (including those written as lambda expressions) concentrated in one place to reduce the back-and-forth between the JVM and the Python script. Since the BatchEvalPython step processes 100 rows per batch, you can also aggregate several rows into one to cut the number of round trips, as shown in the sketch below. Finally, you can rewrite just the UDF part in Java and package it as a jar, while keeping everything else as a Python script that can be modified at any time without compilation, which balances performance and development efficiency.
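A minimal sketch of the "aggregate several rows into one" idea, with hypothetical table and column names: pack each group's values into a single array column so the Python UDF is invoked once per group instead of once per row, shrinking the number of rows that cross the BatchEvalPython boundary.

from pyspark.sql import SparkSession, functions as F
from pyspark.sql.types import ArrayType, IntegerType

spark = SparkSession.builder.appName("udf-row-batching").getOrCreate()

@F.udf(returnType=ArrayType(IntegerType()))
def score_matrices(matrices):
    # One Python invocation now handles a whole group of matrices (placeholder logic).
    return [len(m or "") for m in matrices]

df = spark.table("xxx.xxx")  # hypothetical source table with game_id and matrix columns
batched = (df.groupBy("game_id")
             .agg(F.collect_list("matrix").alias("matrices"))
             .withColumn("matrix_scores", score_matrices("matrices")))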

