Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| from transformers import AutoModelForSequenceClassification | |
| from transformers import AutoTokenizer | |
| from transformers import pipeline | |
| import torch | |
| import os | |
| import numpy as np | |
| from matplotlib import pyplot as plt | |
| from PIL import Image | |
| from pytorch_pretrained_biggan import BigGAN, truncated_noise_sample, one_hot_from_names, one_hot_from_int | |
# Settings for the text classifier (model_name / base_model_name) and the
# BigGAN generator (image_gen_model).
# NOTE(review): max_length, the freeze_* flags and the *_dim entries are not
# read by the code shown here -- confirm whether they are still needed.
config = dict(
    model_name="keras-io/multimodal-entailment",
    base_model_name="distilbert-base-uncased",
    image_gen_model="biggan-deep-512",
    max_length=20,
    freeze_text_model=True,
    freeze_image_gen_model=True,
    text_embedding_dim=768,
    class_embedding_dim=128,
)

# Truncation value shared by the noise sampler and the BigGAN generator.
truncation = 0.4

# Device selection is a manual switch: everything runs on the CPU unless
# is_gpu is flipped to True.
is_gpu = False
device = torch.device('cuda' if is_gpu else 'cpu')
print(device)

# Text classifier. Its argmax label is later used as the BigGAN class index.
# NOTE(review): that index is passed to one_hot_from_int, so this checkpoint is
# presumably expected to produce logits over BigGAN's class set -- confirm.
# NOTE(review): recent transformers releases deprecate use_auth_token in favour
# of token=; check against the pinned version before changing it.
model = AutoModelForSequenceClassification.from_pretrained(
    config["model_name"],
    use_auth_token=os.environ.get('huggingface-api-token'),
)
tokenizer = AutoTokenizer.from_pretrained(config["base_model_name"])
# nn.Module.to() returns the module itself, so eval() can be chained.
model.to(device).eval()

# Pretrained BigGAN generator, inference mode only.
gan_model = BigGAN.from_pretrained(config["image_gen_model"])
gan_model.to(device).eval()
print("Models were loaded")
def generate_image(dense_class_vector=None, int_index=None, noise_seed_vector=None, truncation=0.4):
    """Generate one image with the BigGAN model and return it as a uint8 array.

    The class conditioning comes from ``int_index`` when it is given: it is
    turned into a one-hot vector and passed through ``gan_model.embeddings``.
    Otherwise ``dense_class_vector`` is used directly as the class embedding.

    Args:
        dense_class_vector: pre-computed class embedding (numpy array, tensor
            or nested sequence) with 128 elements; reshaped to (1, 128).
            Ignored when ``int_index`` is given.
        int_index: class index passed to ``one_hot_from_int``.
        noise_seed_vector: optional tensor whose element sum (as an int) seeds
            the noise sample, so equal inputs reproduce the same image.
            ``None`` draws an unseeded sample.
        truncation: truncation value used for both the noise sample and the
            generator call.

    Returns:
        numpy uint8 array of shape (H, W, C) holding the single generated
        image (channels-last).

    Raises:
        ValueError: if neither ``int_index`` nor ``dense_class_vector`` is given.
    """
    if int_index is None and dense_class_vector is None:
        raise ValueError("generate_image needs either int_index or dense_class_vector")

    seed = int(noise_seed_vector.sum().item()) if noise_seed_vector is not None else None
    noise_vector = truncated_noise_sample(truncation=truncation, batch_size=1, seed=seed)
    # Every generator input must live on the same device as the model weights.
    noise_vector = torch.from_numpy(noise_vector).to(device)

    with torch.no_grad():
        if int_index is not None:
            class_vector = torch.from_numpy(one_hot_from_int([int_index], batch_size=1)).to(device)
            dense_class_vector = gan_model.embeddings(class_vector)
        else:
            # Match the noise dtype/device; a float64 numpy input would
            # otherwise make the concatenated input float64 and fail in the
            # float32 generator.
            dense_class_vector = torch.as_tensor(
                dense_class_vector, dtype=noise_vector.dtype, device=device
            )
        dense_class_vector = dense_class_vector.reshape(1, 128)
        input_vector = torch.cat([noise_vector, dense_class_vector], dim=1)
        output = gan_model.generator(input_vector, truncation)

    # NCHW -> NHWC, then map an assumed [-1, 1] generator range onto 0..256
    # and clip into the uint8 range (values outside are clamped).
    images = output.cpu().numpy().transpose((0, 2, 3, 1))
    images = ((images + 1.0) / 2.0) * 256
    np.clip(images, 0, 255, out=images)
    return np.ascontiguousarray(images[0], dtype=np.uint8)
def print_image(numpy_array):
    """Show a uint8 numpy array in a matplotlib window (debugging helper).

    The array is converted to a PIL image first and then handed to
    ``plt.imshow``; ``plt.show()`` opens the window.
    """
    plt.imshow(Image.fromarray(numpy_array))
    plt.show()
def text_to_image(text):
    """Classify ``text`` and render a BigGAN image for the predicted class.

    The argmax of the classifier's logits becomes the BigGAN class index.
    The token ids are passed as the noise seed, so the same text always
    yields the same image.

    Args:
        text: free-form prompt entered in the UI.

    Returns:
        PIL image built from the uint8 array returned by ``generate_image``.
    """
    # truncation=True caps the sequence at the tokenizer's maximum length, so
    # long prompts are shortened instead of exceeding what the classifier accepts.
    tokens = tokenizer.encode(
        text, add_special_tokens=True, truncation=True, return_tensors='pt'
    ).to(device)
    with torch.no_grad():
        lm_output = model(tokens, return_dict=True)
    # argmax over the single sequence's logits yields a 0-d tensor; .item() -> int.
    pred_int_index = int(torch.argmax(lm_output.logits[0], dim=-1).item())
    print(pred_int_index)
    numpy_image = generate_image(int_index=pred_int_index,
                                 truncation=truncation,
                                 noise_seed_vector=tokens)
    return Image.fromarray(numpy_image)
# Example prompts shown under the input box in the Gradio UI.
examples = ["a high resolution photo of a pizza from a famous food magazine.",
            "this is a photo of my pet golden retriever.",
            "this is a photo of a troublesome street cat.",
            "a blurry image of a coral reef.",
            "a yellow taxi cab commonly found in the USA.",
            "Once upon a time, there was a black ship full of pirates.",
            "a photo of a large castle.",
            "a sketch of an old Church"]
if __name__ == '__main__':
    # Web UI: one single-line textbox in, one generated image out.
    # Uses the top-level gr.Textbox / gr.Image components (available since
    # Gradio 3). The old gr.inputs / gr.outputs namespaces were deprecated in
    # Gradio 3 and removed in Gradio 4, and `verbose` was a deprecated no-op.
    interFace = gr.Interface(
        fn=text_to_image,
        inputs=gr.Textbox(
            placeholder="Enter the text to generate an image",
            label="Text query",
            lines=1,
        ),
        # text_to_image returns a PIL image.
        outputs=gr.Image(type="pil", label="Generated Image"),
        examples=examples,
        title="Generate Image from Text",
        description="",
        # NOTE(review): "huggingface" is a legacy theme name; newer Gradio
        # releases may warn and fall back to the default theme -- confirm
        # against the pinned version.
        theme="huggingface",
    )
    interFace.launch()