Page QiView

Temporal Fusion Transformer:可解释多变量预测模型

Temporal Fusion Transformer:可解释多变量预测模型

1. TFT 解决的痛点

多变量业务预测常见两难:

  1. 深度模型精度高但难解释;
  2. 经典模型可解释但表达能力弱。

TFT 通过“门控 + 变量选择 + 注意力”实现兼顾。

2. 结构要点

  1. 变量选择网络(Variable Selection Network)。
  2. 静态协变量编码。
  3. LSTM 局部时序建模。
  4. 多头注意力捕获长程依赖。
  5. 分位数损失输出概率预测。

分位数损失:

$$ \mathcal{L}_q(y,\hat{y})=\max{q(y-\hat{y}),(q-1)(y-\hat{y})} $$

3. Python 示例(PyTorch Forecasting)

from pytorch_forecasting import TemporalFusionTransformer

# training: TimeSeriesDataSet 构建后的 dataloader
model = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=1e-3,
    hidden_size=32,
    attention_head_size=4,
    dropout=0.1,
    output_size=7,  # 分位数个数
)

4. 解释输出怎么用

  1. 看变量重要性:哪些特征在不同阶段主导预测。
  2. 看注意力权重:模型关注了哪些历史时点。
  3. 看分位数区间:把风险区间而非单点传给业务。

5. 常见错误

  1. 时间切分不严格,造成未来信息泄漏。
  2. 只看单一指标,不看覆盖率与分位数校准。
  3. 外生变量未来值不可得却在训练中可见。

TFT 的真正价值在于“可沟通的深度学习预测”,非常适合跨部门协作的运营场景。