from dataclasses import dataclass

import torch
@dataclass
class ModelConfig:
    """Architecture configuration: sequence encoder, VAE adapter, and DiT.

    Declared as a dataclass so fields can be overridden per-run via
    keyword arguments (e.g. ``ModelConfig(latent_dim=512)``) and the
    instance gets a useful ``__repr__`` for logging.
    """

    # --- sequence latent space ---
    # Local path to the embedding encoder. Alternatives:
    # "jinaai/jina-embeddings-v2-base-code", "microsoft/codebert-base", "roberta-base".
    encoder_name: str = "../jina-embeddings-v2-base-code"
    input_dim: int = 768     # Jina Base hidden size is 768
    latent_dim: int = 768    # keep full semantic capacity
    decoder_layers: int = 4  # simple NAR (non-autoregressive) decoder

    # --- VAE adapter ---
    max_seq_len: int = 2048  # set according to the task
    patch_size: int = 4      # patching compression rate

    # --- DiT ---
    dit_layers: int = 12
    dit_heads: int = 8
    dit_hidden: int = 768    # hidden width; keep below latent_dim * patch_size to avoid OOM
    mlp_ratio: float = 4.0
@dataclass
class TrainConfig:
    """Training hyperparameters for the two-stage schedule (AE first, then Flow).

    Must be a dataclass: without the ``@dataclass`` decorator the
    ``__post_init__`` hook below would never run and the checkpoint
    directory would silently not be created.
    """

    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    lr_ae: float = 1e-4
    lr_flow: float = 5e-4
    batch_size: int = 8
    grad_accum_steps: int = 4   # gradient accumulation: effective batch size = 8 * 4 = 32
    num_epochs_ae: int = 20     # train the AE first, then the Flow
    num_epochs_flow: int = 50   # the flow needs more training epochs
    grad_clip: float = 1.0
    use_amp: bool = False       # mixed-precision training; Jina + AMP tends to error out
    save_dir: str = "./checkpoints"

    def __post_init__(self) -> None:
        # Ensure the checkpoint directory exists as soon as the config is built.
        import os
        os.makedirs(self.save_dir, exist_ok=True)