Transformer时间序列预测实战:如何用个人业务数据替换ETTh1进行滚动预测与结果分析
Transformer时间序列预测实战从公开数据集到业务数据的无缝迁移指南当你第一次接触Transformer时间序列预测时可能已经跑通了ETTh1这类公开数据集的Demo。但真正令人头疼的是如何将这套方法迁移到自己的业务数据上本文将带你跨越这道鸿沟从数据格式转换到滚动预测配置手把手实现业务数据的预测落地。1. 业务数据与公开数据集的格式差异解析公开数据集如ETTh1通常已经过标准化处理而业务数据往往存在各种脏数据特征。我们先看一个典型业务CSV与ETTh1的结构对比特征ETTh1数据集业务数据常见问题时间列标准datetime格式可能缺失或格式不统一目标列明确标记为OT需要人工指定目标变量缺失值已处理存在间断或异常值频率明确每小时(h)可能不均匀采样多变量7个相关特征特征相关性未知关键转换步骤# 业务数据预处理示例 import pandas as pd # 读取原始业务数据 raw_data pd.read_csv(sales_data.csv) # 时间列标准化 raw_data[date] pd.to_datetime(raw_data[timestamp]).dt.floor(h) # 按小时对齐 # 处理缺失值 data raw_data.set_index(date).interpolate().reset_index() # 重命名目标列假设销售额是预测目标 data.rename(columns{total_sales: OT}, inplaceTrue) # 保存为模型可读格式 data.to_csv(processed_business_data.csv, indexFalse)注意业务数据的features参数通常选择MS多变量预测单变量因为实际场景中我们往往需要利用所有可用特征来预测核心指标。2. 模型参数的业务化改造策略原始代码中的参数配置需要针对业务数据做针对性调整。以下是关键参数的业务适配指南2.1 时间相关参数parser.add_argument(--freq, typestr, defaulth, help业务数据常见选项: h(小时), d(天), b(工作日)) parser.add_argument(--seq_len, typeint, default168, help建议设置为业务周期整数倍如零售业7天周期可用168(小时)) parser.add_argument(--pred_len, typeint, default24, help根据业务需求设定如预测未来1天设为24)2.2 数据特征配置# 查看业务数据特征数量 import pandas as pd df pd.read_csv(processed_business_data.csv) num_features len(df.columns) - 1 # 减去时间列 parser.add_argument(--enc_in, typeint, defaultnum_features, help编码器输入尺寸业务特征总数) parser.add_argument(--dec_in, typeint, defaultnum_features, help解码器输入尺寸业务特征总数) parser.add_argument(--c_out, typeint, default1, help输出维度通常为1(单变量预测))2.3 训练策略优化parser.add_argument(--train_epochs, typeint, default50, help业务数据通常需要更多训练轮次) parser.add_argument(--batch_size, typeint, default32, help根据GPU内存调整业务数据可能更大) parser.add_argument(--learning_rate, typefloat, default0.0001, help业务数据建议更小的学习率)3. 滚动预测的业务落地技巧滚动预测(rolling forecast)是业务场景中最实用的预测方式其核心在于模拟实时预测环境。我们通过分步拆解实现这一过程3.1 滚动预测数据准备将业务数据按时间排序后分割训练集前80%数据验证集中间10%数据测试集最后10%数据用于滚动预测测试集需要保持与训练集完全相同的特征顺序和格式3.2 滚动预测参数配置parser.add_argument(--rollingforecast, typebool, defaultTrue) parser.add_argument(--rolling_data_path, typestr, defaultbusiness_data_test.csv) parser.add_argument(--label_len, typeint, default72, help建议设置为seq_len的1/3到1/2)3.3 预测结果后处理滚动预测会产生多个预测片段需要拼接并添加时间戳def merge_rolling_forecasts(predictions, test_data): timestamps test_data[date].iloc[-len(predictions):] result pd.DataFrame({ timestamp: timestamps, actual: test_data[OT].iloc[-len(predictions):], predicted: predictions.flatten() }) return result # 保存预测结果 merged_results.to_csv(rolling_forecast_results.csv, indexFalse)4. 业务预测结果的可视化分析不同于学术研究业务预测需要更直观的可视化来支持决策。推荐以下几种专业级可视化方式4.1 动态误差带展示import plotly.graph_objects as go fig go.Figure() fig.add_trace(go.Scatter( xresults[timestamp], yresults[actual], name实际值, linedict(colorblue) )) fig.add_trace(go.Scatter( xresults[timestamp], yresults[predicted], name预测值, linedict(colorred) )) fig.add_trace(go.Scatter( xresults[timestamp], yresults[predicted]*1.1, fillNone, modelines, linedict(width0), showlegendFalse )) fig.add_trace(go.Scatter( xresults[timestamp], yresults[predicted]*0.9, filltonexty, modelines, linedict(width0), name误差范围 )) fig.update_layout(title业务预测结果对比带10%误差范围) fig.show()4.2 关键指标计算表指标公式业务意义MAPE$\frac{100%}{n}\sum...$平均百分比误差RMSE$\sqrt{\frac{1}{n}\sum...$对异常值敏感的绝对误差业务达标率预测误差5%的样本占比直接反映预测可用性# 关键指标计算代码 def business_metrics(actual, predicted): mape np.mean(np.abs((actual - predicted)/actual)) * 100 rmse np.sqrt(np.mean((actual - predicted)**2)) 达标率 np.mean(np.abs((actual - predicted)/actual) 0.05) * 100 return {MAPE: mape, RMSE: rmse, 达标率: 达标率}5. 业务场景中的特殊问题处理实际业务部署时会遇到一些公开数据集中不常见的问题这里提供解决方案5.1 间断性业务数据对于零售业等存在营业时间断层的场景# 创建营业时间掩码 business_hours (data[timestamp].dt.hour 9) (data[timestamp].dt.hour 21) data[valid] business_hours.astype(int) # 在DataEmbedding中添加掩码处理 class BusinessDataEmbedding(nn.Module): def __init__(self, c_in, d_model, dropout0.1): super().__init__() self.value_embedding nn.Linear(c_in, d_model) self.valid_embedding nn.Embedding(2, d_model) # 0/1两种状态 def forward(self, x, x_mark): # x_mark包含valid列 val_embed self.value_embedding(x) valid_embed self.valid_embedding(x_mark[:,-1].long()) return val_embed valid_embed5.2 多周期特征融合业务数据往往包含多个周期特征日周期、周周期等# 在数据预处理阶段添加周期特征 data[day_of_week] data[timestamp].dt.dayofweek data[hour_of_day] data[timestamp].dt.hour data[is_weekend] data[day_of_week] 5 # 修改模型参数 parser.add_argument(--embed, typestr, defaulttimeF, help使用时间特征编码) parser.add_argument(--embed_type, typeint, default0, help启用完整的时间嵌入)6. 模型效果不佳时的业务调优策略当预测效果不理想时可以尝试以下业务导向的优化方法6.1 特征工程增强业务知识特征添加促销活动标记、节假日标记等衍生特征创建同比/环比特征、移动平均特征外部特征整合天气数据、经济指标等外部数据源6.2 模型结构调整# 更适合业务数据的Transformer变体配置 parser.add_argument(--n_heads, typeint, default4, help业务数据通常需要更多注意力头) parser.add_argument(--e_layers, typeint, default3, help更深的编码器捕捉复杂业务模式) parser.add_argument(--d_ff, typeint, default1024, help减小中间维度防止业务数据过拟合)6.3 预测结果后校准建立误差校正模型# 使用预测误差训练校正模型 from sklearn.ensemble import GradientBoostingRegressor # 准备校正训练数据 X_correct predictions[[predicted, hour, day_of_week]] y_correct predictions[actual] - predictions[predicted] # 训练校正模型 corrector GradientBoostingRegressor().fit(X_correct, y_correct) # 应用校正 predictions[corrected] predictions[predicted] corrector.predict(X_correct)7. 业务预测系统的持续优化建立预测监控体系是业务落地的关键性能看板实时显示预测准确率和业务影响自动重训机制当预测误差连续3天超过阈值时自动触发模型重训AB测试框架对比不同模型版本的业务指标提升# 自动化监控示例 class PredictionMonitor: def __init__(self, threshold0.1): self.error_window [] self.threshold threshold def update(self, actual, predicted): error np.mean(np.abs(actual - predicted)/actual) self.error_window.append(error) if len(self.error_window) 3: self.error_window.pop(0) if len(self.error_window) 3 and all(e self.threshold for e in self.error_window): self.trigger_retraining() def trigger_retraining(self): # 实现自动重训逻辑 print(触发模型重训...)