"""Simplified Temporal Fusion Transformer (TFT).

A reduced TFT variant that combines static features, known-future features
and observed time-series features on top of a standard Transformer
encoder/decoder stack.
"""
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class TFTEncoder(nn.Module):
    """Thin wrapper around ``nn.TransformerEncoder`` (batch-first)."""

    def __init__(self, d_model, nhead, num_layers, dim_feedforward, dropout=0.1):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)

    def forward(self, x, mask=None):
        """Encode a sequence.

        Args:
            x: Input of shape [batch_size, seq_len, d_model].
            mask: Optional key-padding mask of shape [batch_size, seq_len];
                it is forwarded as ``src_key_padding_mask``.

        Returns:
            Encoded sequence of shape [batch_size, seq_len, d_model].
        """
        return self.encoder(x, src_key_padding_mask=mask)


class TFTDecoder(nn.Module):
    """Thin wrapper around ``nn.TransformerDecoder`` (batch-first)."""

    def __init__(self, d_model, nhead, num_layers, dim_feedforward, dropout=0.1):
        super().__init__()
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
        """Decode a target sequence against the encoder memory.

        Args:
            tgt: Target sequence of shape [batch_size, tgt_len, d_model].
            memory: Encoder output of shape [batch_size, seq_len, d_model].
            tgt_mask: Optional attention mask for the target sequence.
            memory_mask: Optional mask over the memory. Despite its name it
                is forwarded as ``memory_key_padding_mask`` (a per-position
                padding mask), not as the decoder's ``memory_mask``.

        Returns:
            Decoded sequence of shape [batch_size, tgt_len, d_model].
        """
        return self.decoder(
            tgt,
            memory,
            tgt_mask=tgt_mask,
            memory_key_padding_mask=memory_mask,
        )


class TemporalFusionTransformer(nn.Module):
    """Simplified Temporal Fusion Transformer.

    Supports static features, known-future features and observed
    time-series features.
    """

    def __init__(self, num_observed_features, num_static_features, num_known_future_features,
                 num_output_features=None, hidden_size=128, num_heads=4,
                 num_encoder_layers=3, num_decoder_layers=3,
                 dim_feedforward=512, dropout=0.1):
        """Build the model.

        Args:
            num_observed_features: Dimension of the observed time-series
                features (the original note says these come from a Phased
                LSTM output).
            num_static_features: Dimension of the static features.
            num_known_future_features: Dimension of the known-future features.
            num_output_features: Dimension of the predictions. Falls back to
                ``num_observed_features`` when ``None`` (or any falsy value).
            hidden_size: Model width used by all embeddings and attention.
            num_heads: Number of attention heads.
            num_encoder_layers: Number of encoder layers.
            num_decoder_layers: Number of decoder layers.
            dim_feedforward: Width of the feed-forward sub-layers.
            dropout: Dropout probability.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.num_observed_features = num_observed_features
        self.num_static_features = num_static_features
        self.num_known_future_features = num_known_future_features
        # By default the output has the same width as the observed features.
        self.num_output_features = num_output_features or num_observed_features

        # Feature embeddings.
        self.observed_embedding = nn.Linear(num_observed_features, hidden_size)
        self.static_embedding = nn.Linear(num_static_features, hidden_size)
        self.known_future_embedding = nn.Linear(num_known_future_features, hidden_size)

        # Positional encoding, shared by the observed and known-future inputs.
        self.pos_encoder = PositionalEncoding(hidden_size, dropout)

        self.encoder = TFTEncoder(
            d_model=hidden_size,
            nhead=num_heads,
            num_layers=num_encoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )

        self.decoder = TFTDecoder(
            d_model=hidden_size,
            nhead=num_heads,
            num_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )

        # Projects decoder states to the predicted features.
        self.output_layer = nn.Linear(hidden_size, self.num_output_features)

        # Fuses the (broadcast) static embedding into every time step.
        self.static_fusion = nn.Linear(hidden_size * 2, hidden_size)

    @staticmethod
    def _report_nan(enabled, tensor, message):
        """Print ``message`` when debugging is enabled and ``tensor`` has NaNs."""
        # The NaN scan only runs when debugging is on.
        if enabled and torch.isnan(tensor).any():
            print(message)

    @staticmethod
    def _to_key_padding_mask(mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        """Convert a user mask to a [batch_size, seq_len] key-padding mask.

        * 3D input ``[batch, seq, features]``: 1 means valid, 0 means missing;
          a time step is masked only when all of its features are missing.
        * 2D input ``[batch, seq]``: truthy entries are masked positions.
        * Any other rank (or ``None``): no mask is used.

        Returns ``None`` when every position of the whole batch would be
        masked. When only some samples are fully masked, their first time
        step is unmasked so attention never sees an all-masked row. The
        caller's tensor is never modified.
        """
        if mask is None:
            return None

        if mask.dim() == 3:
            padding_mask = mask.sum(dim=-1) == 0
        elif mask.dim() == 2:
            # NOTE: ``.bool()`` returns the very same tensor when the dtype is
            # already bool, so it must not be written to in place.
            padding_mask = mask.bool()
        else:
            return None

        if padding_mask.all():
            # Everything is masked: fall back to running without a mask.
            return None

        fully_masked = padding_mask.all(dim=1)
        if fully_masked.any():
            # Clone first so the caller's mask is left untouched.
            padding_mask = padding_mask.clone()
            padding_mask[fully_masked, 0] = False
        return padding_mask

    def forward(self, observed_features, static_features, known_future_features, mask=None, debug=False):
        """Run the model.

        Args:
            observed_features: Observed time-series features,
                [batch_size, seq_len, num_observed_features].
            static_features: Static features, [batch_size, num_static_features].
            known_future_features: Known-future features,
                [batch_size, pred_len, num_known_future_features].
            mask: Optional mask, either [batch_size, seq_len, num_features]
                or [batch_size, seq_len].
                3D: 1 = valid, 0 = missing; a time step is masked when all
                of its features are missing.
                2D: True = position to mask (padding), False = valid.
                Masks of any other rank are ignored. The tensor passed in is
                not modified.
            debug: When True, print input shapes and NaN diagnostics.

        Returns:
            Predictions of shape [batch_size, pred_len, num_output_features].
        """
        # Unpacking three values keeps the implicit "input must be 3D" check.
        _, seq_len, _ = observed_features.size()

        if debug:
            print(f" [TFT] 输入检查:")
            print(f" observed_features: shape={observed_features.shape}, has_nan={torch.isnan(observed_features).any().item()}")
            print(f" static_features: shape={static_features.shape}, has_nan={torch.isnan(static_features).any().item()}")
            print(f" known_future_features: shape={known_future_features.shape}, has_nan={torch.isnan(known_future_features).any().item()}")

        # Embed the observed features: [batch_size, seq_len, hidden_size].
        observed_emb = self.observed_embedding(observed_features)
        self._report_nan(debug, observed_emb, " ⚠️ observed_emb包含NaN(在embedding后)")

        observed_emb = self.pos_encoder(observed_emb)
        self._report_nan(debug, observed_emb, " ⚠️ observed_emb包含NaN(在pos_encoder后)")

        # Embed the static features and broadcast them over time.
        static_emb = self.static_embedding(static_features)  # [batch_size, hidden_size]
        self._report_nan(debug, static_emb, " ⚠️ static_emb包含NaN(在embedding后)")
        static_emb = static_emb.unsqueeze(1).expand(-1, seq_len, -1)  # [batch_size, seq_len, hidden_size]

        # Fuse static and observed embeddings.
        encoder_input = torch.cat([observed_emb, static_emb], dim=-1)  # [batch_size, seq_len, hidden_size*2]
        self._report_nan(debug, encoder_input, " ⚠️ encoder_input包含NaN(在concat后)")

        encoder_input = self.static_fusion(encoder_input)  # [batch_size, seq_len, hidden_size]
        self._report_nan(debug, encoder_input, " ⚠️ encoder_input包含NaN(在static_fusion后)")

        # Normalise the user mask into a [batch_size, seq_len] padding mask.
        mask_2d = self._to_key_padding_mask(mask)

        # Encode: [batch_size, seq_len, hidden_size].
        encoder_output = self.encoder(encoder_input, mask=mask_2d)
        self._report_nan(debug, encoder_output, " ⚠️ encoder_output包含NaN(在encoder后)")

        # Embed the known-future features: [batch_size, pred_len, hidden_size].
        known_future_emb = self.known_future_embedding(known_future_features)
        self._report_nan(debug, known_future_emb, " ⚠️ known_future_emb包含NaN(在embedding后)")

        # NOTE: the shared positional encoding restarts at position 0 here,
        # i.e. future steps are not offset by seq_len.
        known_future_emb = self.pos_encoder(known_future_emb)
        self._report_nan(debug, known_future_emb, " ⚠️ known_future_emb包含NaN(在pos_encoder后)")

        # Decode against the encoder memory, reusing the same padding mask.
        decoder_output = self.decoder(
            known_future_emb,
            encoder_output,
            memory_mask=mask_2d,
        )  # [batch_size, pred_len, hidden_size]
        self._report_nan(debug, decoder_output, " ⚠️ decoder_output包含NaN(在decoder后)")

        # Project to the output features: [batch_size, pred_len, num_output_features].
        predictions = self.output_layer(decoder_output)
        self._report_nan(debug, predictions, " ⚠️ predictions包含NaN(在output_layer后)")

        return predictions


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding followed by dropout."""

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        """Precompute the encoding table.

        Args:
            d_model: Embedding width (even or odd).
            dropout: Dropout probability applied after adding the encoding.
            max_len: Number of positions precomputed in the table.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one cosine column fewer than sine columns,
        # so trim div_term to match; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the positional encoding to ``x`` ([batch, seq_len, d_model])."""
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)