Spaces:
Runtime error
Runtime error
Commit
·
ef6f6bd
1
Parent(s):
d83af99
random calibration & they"re kept in the df
Browse files
app.py
CHANGED
|
@@ -102,11 +102,17 @@ def sample_embs(prompt_embeds):
|
|
| 102 |
def get_user_emb(embs, ys):
|
| 103 |
positives = [e for e, ys in zip(embs, ys) if ys == 1]
|
| 104 |
embs = random.sample(positives, min(8, len(positives)))
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
| 106 |
|
| 107 |
negs = [e for e, ys in zip(embs, ys) if ys == 0]
|
| 108 |
negative_embs = random.sample(negs, min(8, len(negs)))
|
| 109 |
-
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
image_embeds = torch.stack([sample_embs(negatives), sample_embs(positives)])
|
| 112 |
|
|
@@ -202,10 +208,12 @@ def pluck_img(user_id):
|
|
| 202 |
|
| 203 |
def next_image(calibrate_prompts, user_id):
|
| 204 |
with torch.no_grad():
|
| 205 |
-
|
| 206 |
-
|
|
|
|
| 207 |
image = prevs_df[prevs_df['paths'] == cal_video]['paths'].to_list()[0]
|
| 208 |
return image, calibrate_prompts
|
|
|
|
| 209 |
else:
|
| 210 |
image = pluck_img(user_id)
|
| 211 |
return image, calibrate_prompts
|
|
@@ -330,7 +338,7 @@ Explore the latent space without text prompts based on your preferences. [rynmur
|
|
| 330 |
''', elem_id="description")
|
| 331 |
user_id = gr.State()
|
| 332 |
# calibration videos -- this is a misnomer now :D
|
| 333 |
-
calibrate_prompts = gr.State(
|
| 334 |
def l():
|
| 335 |
return None
|
| 336 |
|
|
@@ -428,8 +436,8 @@ for im in m_calibrate:
|
|
| 428 |
tmp_df['embeddings'] = [im_emb.detach().to('cpu')]
|
| 429 |
tmp_df['user:rating'] = [{' ': ' '}]
|
| 430 |
tmp_df['text'] = ['']
|
| 431 |
-
|
| 432 |
-
|
| 433 |
prevs_df = pd.concat((prevs_df, tmp_df))
|
| 434 |
|
| 435 |
glob_idx = 0
|
|
|
|
| 102 |
def get_user_emb(embs, ys):
|
| 103 |
positives = [e for e, ys in zip(embs, ys) if ys == 1]
|
| 104 |
embs = random.sample(positives, min(8, len(positives)))
|
| 105 |
+
if len(embs) == 0:
|
| 106 |
+
positives = torch.zeros_like(im_emb)[None]
|
| 107 |
+
else:
|
| 108 |
+
positives = torch.stack(embs, 1)
|
| 109 |
|
| 110 |
negs = [e for e, ys in zip(embs, ys) if ys == 0]
|
| 111 |
negative_embs = random.sample(negs, min(8, len(negs)))
|
| 112 |
+
if len(negative_embs) == 0:
|
| 113 |
+
negatives = torch.zeros_like(im_emb)[None]
|
| 114 |
+
else:
|
| 115 |
+
negatives = torch.stack(negative_embs, 1)
|
| 116 |
|
| 117 |
image_embeds = torch.stack([sample_embs(negatives), sample_embs(positives)])
|
| 118 |
|
|
|
|
| 208 |
|
| 209 |
def next_image(calibrate_prompts, user_id):
|
| 210 |
with torch.no_grad():
|
| 211 |
+
# once we've done so many random calibration prompts out of the full media
|
| 212 |
+
if len(m_calibrate) - len(calibrate_prompts) < 5:
|
| 213 |
+
cal_video = calibrate_prompts.pop(random.randint(0, len(calibrate_prompts)-1))
|
| 214 |
image = prevs_df[prevs_df['paths'] == cal_video]['paths'].to_list()[0]
|
| 215 |
return image, calibrate_prompts
|
| 216 |
+
# we switch to just getting media by similarity.
|
| 217 |
else:
|
| 218 |
image = pluck_img(user_id)
|
| 219 |
return image, calibrate_prompts
|
|
|
|
| 338 |
''', elem_id="description")
|
| 339 |
user_id = gr.State()
|
| 340 |
# calibration videos -- this is a misnomer now :D
|
| 341 |
+
calibrate_prompts = gr.State( glob.glob('image_init/*') )
|
| 342 |
def l():
|
| 343 |
return None
|
| 344 |
|
|
|
|
| 436 |
tmp_df['embeddings'] = [im_emb.detach().to('cpu')]
|
| 437 |
tmp_df['user:rating'] = [{' ': ' '}]
|
| 438 |
tmp_df['text'] = ['']
|
| 439 |
+
tmp_df['from_user_id'] = [0]
|
| 440 |
+
tmp_df['latest_user_to_rate'] = [0]
|
| 441 |
prevs_df = pd.concat((prevs_df, tmp_df))
|
| 442 |
|
| 443 |
glob_idx = 0
|