Temporal Fusion Transformer:可解释多变量预测模型
1. TFT 解决的痛点
多变量业务预测常见两难:
- 深度模型精度高但难解释;
- 经典模型可解释但表达能力弱。
TFT 通过“门控 + 变量选择 + 注意力”实现兼顾。
2. 结构要点
- 变量选择网络(Variable Selection Network)。
- 静态协变量编码。
- LSTM 局部时序建模。
- 多头注意力捕获长程依赖。
- 分位数损失输出概率预测。
分位数损失:
$$ \mathcal{L}_q(y,\hat{y})=\max{q(y-\hat{y}),(q-1)(y-\hat{y})} $$
3. Python 示例(PyTorch Forecasting)
from pytorch_forecasting import TemporalFusionTransformer
# training: TimeSeriesDataSet 构建后的 dataloader
model = TemporalFusionTransformer.from_dataset(
training,
learning_rate=1e-3,
hidden_size=32,
attention_head_size=4,
dropout=0.1,
output_size=7, # 分位数个数
)
4. 解释输出怎么用
- 看变量重要性:哪些特征在不同阶段主导预测。
- 看注意力权重:模型关注了哪些历史时点。
- 看分位数区间:把风险区间而非单点传给业务。
5. 常见错误
- 时间切分不严格,造成未来信息泄漏。
- 只看单一指标,不看覆盖率与分位数校准。
- 外生变量未来值不可得却在训练中可见。
TFT 的真正价值在于“可沟通的深度学习预测”,非常适合跨部门协作的运营场景。