况兑 committed on
Commit
a769d64
·
1 Parent(s): a088db3

fix: cast attention_mask to bool to satisfy Dream forward/generate expectations

Browse files
Files changed (1) hide show
  1. loss_probe.py +13 -23
loss_probe.py CHANGED
@@ -5,33 +5,27 @@ MODEL_ID = os.getenv("MODEL_ID", "Dream-org/Dream-v0-Instruct-7B")
5
  REV = os.getenv("REV", None)
6
 
7
  print(f"[INFO] Using MODEL_ID={MODEL_ID} REV={REV or '(latest)'}")
8
- print("[INFO] Loading tokenizer...")
9
  tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, revision=REV)
10
 
11
- print("[INFO] Loading model...")
12
  dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
13
- model = AutoModel.from_pretrained(
14
- MODEL_ID, trust_remote_code=True, torch_dtype=dtype, revision=REV
15
- )
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
- model = model.to(device).eval()
18
 
19
  def check_loss():
20
  msgs = [
21
  {"role": "system", "content": "只输出一个数字"},
22
  {"role": "user", "content": "Compute: 1+1"},
23
  ]
24
- enc = tok.apply_chat_template(
25
- msgs, return_tensors="pt", return_dict=True, add_generation_prompt=False
26
- )
27
- enc = {k: v.to(device) for k, v in enc.items()}
28
- labels = enc["input_ids"].clone()
 
 
 
29
  try:
30
- out = model(
31
- input_ids=enc["input_ids"],
32
- attention_mask=enc.get("attention_mask"),
33
- labels=labels,
34
- )
35
  has_loss = getattr(out, "loss", None) is not None
36
  return f"[CHECK] supports labels->loss? {has_loss} | type={type(out)}"
37
  except Exception as e:
@@ -41,11 +35,9 @@ def quick_infer(q: str):
41
  if not q.strip():
42
  return ""
43
  messages = [{"role": "user", "content": q}]
44
- inputs = tok.apply_chat_template(
45
- messages, return_tensors="pt", return_dict=True, add_generation_prompt=True
46
- )
47
  input_ids = inputs.input_ids.to(device)
48
- attention_mask = inputs.attention_mask.to(device)
49
  with torch.no_grad():
50
  out = model.diffusion_generate(
51
  input_ids,
@@ -55,9 +47,7 @@ def quick_infer(q: str):
55
  temperature=0.0,
56
  return_dict_in_generate=True,
57
  )
58
- text = tok.decode(
59
- out.sequences[0][input_ids.shape[1]:], skip_special_tokens=True
60
- ).strip()
61
  return text
62
 
63
  with gr.Blocks() as demo:
 
5
  REV = os.getenv("REV", None)
6
 
7
  print(f"[INFO] Using MODEL_ID={MODEL_ID} REV={REV or '(latest)'}")
 
8
  tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, revision=REV)
9
 
 
10
  dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
 
 
 
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
+ model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True, torch_dtype=dtype, revision=REV).to(device).eval()
13
 
14
  def check_loss():
15
  msgs = [
16
  {"role": "system", "content": "只输出一个数字"},
17
  {"role": "user", "content": "Compute: 1+1"},
18
  ]
19
+ enc = tok.apply_chat_template(msgs, return_tensors="pt", return_dict=True, add_generation_prompt=False)
20
+ # Ensure dtype / device are correct; a bool attention_mask is compatible
21
+ input_ids = enc["input_ids"].to(device)
22
+ attn = enc.get("attention_mask", None)
23
+ if attn is not None:
24
+ attn = attn.to(device).to(torch.bool)
25
+ labels = input_ids.clone()
26
+
27
  try:
28
+ out = model(input_ids=input_ids, attention_mask=attn, labels=labels)
 
 
 
 
29
  has_loss = getattr(out, "loss", None) is not None
30
  return f"[CHECK] supports labels->loss? {has_loss} | type={type(out)}"
31
  except Exception as e:
 
35
  if not q.strip():
36
  return ""
37
  messages = [{"role": "user", "content": q}]
38
+ inputs = tok.apply_chat_template(messages, return_tensors="pt", return_dict=True, add_generation_prompt=True)
 
 
39
  input_ids = inputs.input_ids.to(device)
40
+ attention_mask = inputs.attention_mask.to(device).to(torch.bool)
41
  with torch.no_grad():
42
  out = model.diffusion_generate(
43
  input_ids,
 
47
  temperature=0.0,
48
  return_dict_in_generate=True,
49
  )
50
+ text = tok.decode(out.sequences[0][input_ids.shape[1]:], skip_special_tokens=True).strip()
 
 
51
  return text
52
 
53
  with gr.Blocks() as demo: