您好,登錄后才能下訂單哦!
object SparkSqlTest {
  def main(args: Array[String]): Unit = {
    // Suppress noisy framework logging.
    Logger.getLogger("org.apache.hadoop").setLevel(Level.WARN)
    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
    Logger.getLogger("org.project-spark").setLevel(Level.WARN)
    // Build the Spark entry points.
    val conf: SparkConf = new SparkConf()
    conf.setAppName("SparkSqlTest")
      .setMaster("local[2]")
    val spark: SparkSession = SparkSession.builder().config(conf)
      .getOrCreate()
    // Obtain the SQLContext used to register the UDF.
    val sqlContext: SQLContext = spark.sqlContext
    /**
     * Register the custom UDF.
     * Type parameters [Int, String]: the first is the return type,
     * the following one or more are the function's parameter types.
     */
    sqlContext.udf.register[Int, String]("strLen", strLen _)
    val sql =
      """
        |select strLen("zhangsan")
      """.stripMargin
    spark.sql(sql).show()
  }

  // Custom UDF: returns the length of the given string.
  // Return type must be Int (not java.lang.Integer) so the method
  // eta-expands to the String => Int required by register[Int, String].
  def strLen(str: String): Int = str.length
}
這里舉的例子是實現一個count:
自定義UDAF類:
/**
 * Custom UDAF implementing count: each input row adds one to an Int
 * accumulator; partial counts are summed when partitions are merged.
 */
class MyCountUDAF extends UserDefinedAggregateFunction {

  // Schema of the values fed into this UDAF.
  override def inputSchema: StructType =
    new StructType().add(StructField("age", DataTypes.IntegerType))

  // Schema of the intermediate aggregation buffer.
  override def bufferSchema: StructType =
    new StructType().add(StructField("age", DataTypes.IntegerType))

  // Type of the value produced by evaluate().
  override def dataType: DataType = DataTypes.IntegerType

  // Deterministic: identical input always yields identical output.
  override def deterministic: Boolean = true

  // Initialize the temporary buffer that holds the running count.
  override def initialize(buffer: MutableAggregationBuffer): Unit =
    buffer.update(0, 0)

  /**
   * Fold one incoming row into a partition-local buffer.
   * @param buffer the buffer set up in initialize()
   * @param input  the newly arrived row
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    buffer.update(0, buffer.getInt(0) + 1)

  /**
   * Combine partial results across partitions.
   * @param buffer1 partial result of one partition (receives the sum)
   * @param buffer2 partial result of another partition
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    buffer1.update(0, buffer1.getInt(0) + buffer2.getInt(0))

  // Final value emitted by the aggregation.
  override def evaluate(buffer: Row): Any = buffer.get(0)
}
調用:
object SparkSqlTest {
  def main(args: Array[String]): Unit = {
    // Suppress noisy framework logging.
    Logger.getLogger("org.apache.hadoop").setLevel(Level.WARN)
    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
    Logger.getLogger("org.project-spark").setLevel(Level.WARN)
    // Build the Spark entry points; register Student with Kryo so the
    // dataset elements serialize efficiently.
    val conf: SparkConf = new SparkConf()
    conf.setAppName("SparkSqlTest")
      .setMaster("local[2]")
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .registerKryoClasses(Array(classOf[Student]))
    val spark: SparkSession = SparkSession.builder().config(conf)
      .getOrCreate()
    // Obtain the SQLContext used to register the UDAF.
    val sqlContext: SQLContext = spark.sqlContext
    // Register the UDAF under the name referenced in the SQL below.
    sqlContext.udf.register("myCount", new MyCountUDAF())
    // Case classes are built via the companion apply method; `new` is
    // unnecessary.
    val stuList = List(
      Student("委xx", 18),
      Student("吳xx", 18),
      Student("戚xx", 18),
      Student("王xx", 19),
      Student("薛xx", 19)
    )
    import spark.implicits._
    val stuDS: Dataset[Student] = sqlContext.createDataset(stuList)
    stuDS.createTempView("student")
    val sql =
      """
        |select myCount(1) counts
        |from student
        |group by age
        |order by counts
      """.stripMargin
    spark.sql(sql).show()
  }
}
// Immutable record for one student (name and age); serves as the
// element type of the Dataset aggregated by MyCountUDAF.
case class Student(name:String,age:Int)
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。