from dataclasses import dataclass

import torch


@dataclass
class ModelConfig:
    """Architecture hyperparameters for the encoder / VAE adapter / DiT stack."""

    # --- sequence latent space config ---
    # Local path to the encoder checkpoint. Alternatives:
    # "jinaai/jina-embeddings-v2-base-code", "microsoft/codebert-base", or "roberta-base".
    encoder_name: str = "../jina-embeddings-v2-base-code"
    input_dim: int = 768   # Jina Base hidden size is 768
    latent_dim: int = 768  # keep full width to preserve maximal semantics
    decoder_layers: int = 4  # simple NAR (non-autoregressive) decoder depth

    # --- VAE adapter config ---
    max_seq_len: int = 2048  # set according to the task
    patch_size: int = 4      # patching compression rate

    # --- DiT settings ---
    dit_layers: int = 12
    dit_heads: int = 8
    # Hidden width; kept below latent_dim * patch_size to reduce OOM risk.
    dit_hidden: int = 768
    mlp_ratio: float = 4.0


@dataclass
class TrainConfig:
    """Training hyperparameters for the two-stage AE-then-Flow schedule."""

    # NOTE: evaluated once at class-definition time, as in the original.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    lr_ae: float = 1e-4
    lr_flow: float = 5e-4
    batch_size: int = 8
    grad_accum_steps: int = 4  # gradient accumulation -> effective batch size 32
    num_epochs_ae: int = 20    # train the AE first, then the Flow
    num_epochs_flow: int = 50  # the flow needs more training epochs
    grad_clip: float = 1.0
    # Mixed-precision training; Jina + AMP tends to error out, so off by default.
    use_amp: bool = False
    save_dir: str = "./checkpoints"

    def __post_init__(self):
        # Ensure the checkpoint directory exists before any save is attempted.
        import os
        os.makedirs(self.save_dir, exist_ok=True)