from dataclasses import dataclass
import torch
@dataclass
class ModelConfig:
    """Architecture hyperparameters for the encoder / VAE-adapter / DiT stack.

    All fields are plain defaults; construct with keyword overrides, e.g.
    ``ModelConfig(dit_layers=24)``.
    """
    # --- sequence latent-space (encoder + NAR decoder) config ---
    # HuggingFace model id or local path of the text encoder.
    # Alternatives: "jinaai/jina-embeddings-v2-base-code",
    # "microsoft/codebert-base", or "roberta-base".
    encoder_name: str = "../jina-embeddings-v2-base-code"
    input_dim: int = 768      # Jina base hidden size is 768
    latent_dim: int = 768     # kept equal to input_dim to retain maximal semantics
    decoder_layers: int = 4   # simple non-autoregressive (NAR) decoder depth

    # --- VAE adapter config ---
    max_seq_len: int = 2048   # set according to the task's sequence length
    patch_size: int = 4       # patching compression rate (tokens per patch)

    # --- DiT (diffusion transformer) settings ---
    dit_layers: int = 12
    dit_heads: int = 8
    # Hidden width; kept below latent_dim * patch_size to reduce OOM risk.
    dit_hidden: int = 768
    mlp_ratio: float = 4.0
@dataclass
class TrainConfig:
    """Training hyperparameters for the two-stage AE-then-Flow training run.

    Instantiating this class creates ``save_dir`` on disk (idempotent) via
    ``__post_init__``.
    """
    # NOTE: evaluated once at class-definition time, not per instance.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    lr_ae: float = 1e-4       # autoencoder learning rate
    lr_flow: float = 5e-4     # flow-model learning rate
    batch_size: int = 8
    # Gradient accumulation: effective batch size = batch_size * grad_accum_steps = 32.
    grad_accum_steps: int = 4
    num_epochs_ae: int = 20   # train the AE first, then the Flow
    num_epochs_flow: int = 50 # the flow needs more training epochs
    grad_clip: float = 1.0
    # Mixed-precision training; disabled because Jina + AMP tends to error out.
    use_amp: bool = False
    save_dir: str = "./checkpoints"

    def __post_init__(self):
        """Ensure the checkpoint directory exists (no-op if already present)."""
        import os
        os.makedirs(self.save_dir, exist_ok=True)