File size: 1,406 Bytes
77d636f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from dataclasses import dataclass
import torch

@dataclass
class ModelConfig:
    """Architecture hyperparameters: sequence latent space, VAE adapter, DiT.

    Attributes:
        encoder_name: Encoder identifier; the default is a relative local
            path. Alternatives noted by the author:
            "jinaai/jina-embeddings-v2-base-code",
            "microsoft/codebert-base", or roberta-base.
        input_dim: Encoder output width (Jina Base is 768).
        latent_dim: Latent width; kept at 768 to retain as much semantic
            information as possible.
        decoder_layers: Layer count of the simple NAR decoder.
        max_seq_len: Maximum sequence length; set according to the task.
        patch_size: Patching compression rate.
        dit_layers: Number of DiT layers.
        dit_heads: Number of DiT attention heads.
        dit_hidden: DiT hidden width; kept below latent_dim * patch_size
            to reduce out-of-memory failures.
        mlp_ratio: MLP ratio for the DiT blocks.
    """

    # --- Sequence latent space ---
    encoder_name: str = "../jina-embeddings-v2-base-code"
    input_dim: int = 768
    latent_dim: int = 768
    decoder_layers: int = 4

    # --- VAE adapter ---
    max_seq_len: int = 2048
    patch_size: int = 4

    # --- DiT ---
    dit_layers: int = 12
    dit_heads: int = 8
    dit_hidden: int = 768
    mlp_ratio: float = 4.0

    # NOTE: dit_hidden is a plain field with its own default; an earlier
    # variant exposed it as a property returning latent_dim.

@dataclass
class TrainConfig:
    """Training hyperparameters and the checkpoint output location.

    Constructing an instance creates ``save_dir`` (including missing
    parent directories) if it does not already exist.

    Attributes:
        device: "cuda" if torch reported CUDA as available when this class
            body was executed, otherwise "cpu".
        lr_ae: Learning rate for the AE stage.
        lr_flow: Learning rate for the Flow stage.
        batch_size: Per-step batch size.
        grad_accum_steps: Gradient accumulation steps; with the defaults
            this is equivalent to a batch size of 32.
        num_epochs_ae: AE epochs; the AE is trained first, then the Flow.
        num_epochs_flow: Flow epochs; the Flow needs more training epochs.
        grad_clip: Gradient clipping value.
        use_amp: Mixed-precision training switch; off by default because
            Jina combined with AMP tends to raise errors.
        save_dir: Directory created on construction.
    """

    # Evaluated once at class-definition time, not per instance.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Optimisation
    lr_ae: float = 1e-4
    lr_flow: float = 5e-4
    batch_size: int = 8
    grad_accum_steps: int = 4

    # Schedule
    num_epochs_ae: int = 20
    num_epochs_flow: int = 50

    # Stability and output
    grad_clip: float = 1.0
    use_amp: bool = False
    save_dir: str = "./checkpoints"

    def __post_init__(self):
        """Create ``save_dir`` if needed; an existing directory is fine."""
        from os import makedirs

        makedirs(self.save_dir, exist_ok=True)