为什么要用高斯过程回归
- 现实生活中,我们遇到的一个典型问题就是选择合适的模型拟合训练集中自变量 $X$ 与因变量 $y$ 之间的关系,并根据新的自变量 $x$ 来预测对应的因变量 $f$
$$p(f|x,X,y)$$ - 如果关系足够简单,那么线性回归就能实现很好的预测,但现实情况往往十分复杂,此时,高斯过程回归就为我们提供了拟合复杂关系(quadratic, cubic, or even nonpolynomial)的绝佳方法
什么是高斯过程回归
- 高斯过程可以看做是多维高斯分布向无限维的扩展,我们可以将 $y={y_1,y_2,…,y_n}$ 看作是从 $n$ 维高斯分布中随机抽取的一个点
- 对高斯过程的刻画,如同高斯分布一样,也是用均值和方差来刻画。通常在应用高斯过程 $f \sim GP(m,K)$ 的方法中,都是假设均值 $m$ 为零,而协方差函数 $K$ 则是根据具体应用而定
- 高斯回归的本质其实就是通过一个映射把自变量从低维空间映射到高维空间(类似于支持向量机中的核函数将低维线性不可分映射为高维线性可分),只需找到合适的核函数,就可以知道 $p(f|x,X,y)$ 的分布,最常用的就是高斯核函数
高斯过程回归的基本流程
- 再利用高斯过程回归时,不需要指明 $f(x)$ 的具体形式,如线性 $f(x)=mx+c$,或者二次 $f(x)=ax^2+bx+c$ 等具体形式,n 个训练集的观测值 ${y_1,y_2,…,y_n}$ 会被看做多维(n 维)高斯分布中采样出来的一个点
- 现在给定训练集 ${x_1,x_2,…,x_n}$ 与对应的观测值 ${y_1,y_2,…,y_n}$,由于观测通常是带噪声的,所以将每个观测 $y$ 建模为某个隐函数 $f(x)$ 加上一个高斯噪声,即
$$y=f(x)+N(0,\sigma_n^2)$$ - 其中,$f(x)$ 被假定给予一个高斯过程先验,即
$$f(x) \sim GP(0,K)$$ - 其中协方差函数 $k(x,x’)$ 可以选择不同的单一形式,也可以采用协方差函数的组合形式,由于假设均值为零,因此最后结果的好坏很大程度上取决于协方差函数的选择。不同的协方差函数形式参见这篇文章对 Covariance Functions 的详细介绍。常见的协方差函数如下,参见 Wikipedia-Gaussian Process
- 根据高斯分布的性质以及测试集和训练集数据来自同一分布的特点,可以得到训练数据与测试数据的联合分布为高维的高斯分布,有了联合分布就可以比较容易地求出预测数据 $y^\ast$ 的条件分布 $p(y^\ast|y)$,对 $y^\ast$ 的估计,我们就用分布的均值来作为其估计值,具体推导参见 Reference
利用高斯过程进行时间序列预测
- R 中
kernlab
包的gausspr
函数可以进行高斯回归,并实现预测,以下面这个包含 46 个月的时间序列 ts7 为例
利用趋势回归并进行预测
|
|
只利用趋势项进行高斯回归的拟合效果如下
然后用过去三年的时间序列作为训练集对未来一个月的需求进行循环预测
123456temp1 <- data.table(ts7, fitted=0)for (k in 0:9){train <- temp1[(1+k):(36+k),3:4]fit <- gausspr(demand~t, data=train)temp1[(37+k), "fitted"] <- predict(fit, temp1[(37+k),.(t)])}
利用趋势+季节回归并进行预测
- 首先,去除趋势之后,检查去趋势之后的时间序列是否具有明显的季节性,并找出 CV 最小的前三个季节12temp$demand_detrend <- temp$demand - temp$fittedggplot(temp, aes(x=year_month, group=1)) + geom_line(aes(y=demand, col="demand"), size=1) + geom_line(aes(y=fitted, col="fitted"), size=1) + geom_line(aes(y=demand_detrend, col="demand_detrend")) + theme_bw() + theme(axis.text.x=element_text(angle=45,hjust=1,vjust=1)) + scale_x_discrete(breaks=temp$year_month[seq(2,44,3)])
- 季节性雷达图1ggseasonplot(temp$demand_detrend, polar=TRUE) + ggtitle("Seasonal Plot") + geom_line(size=1) + theme_bw()
- 季节性箱形图1ggplot(temp, aes(x=month, y=demand_detrend)) + geom_boxplot() + theme_bw()
获取季节 cv 最小的前 3 个季节分别是12月、2月、10月
12as.numeric(temp[, .(cv=sd(demand_detrend)/mean(abs(demand_detrend))), by=month][order(cv)][1:3,month])# [1] 12 2 10加入全部 12 个月作为季节性之后,再对最后的 10 个月进行循环预测
123456temp2 <- data.table(ts7, fitted=0)for (k in 0:9){train <- temp2[(1+k):(36+k),3:15]fit <- gausspr(demand~., data=train)temp2[(37+k), "fitted"] <- predict(fit, temp2[(37+k),4:15])}
预测结果比较
|
|
总结
- 当随机变量呈现明显的非线性趋势时,高斯过程回归能够很好地预测线性预测的不足
- 季节性并不一定能够提高预测效果,当某些月份的需求变动幅度很大时,加入季节虚拟变量反而会增大预测误差
- 高斯过程不仅能用于回归预测,还能用于解决分类问题,有兴趣的读者请自行探究
Reference