import jax
import jax.numpy as jnp
from flax import jax_utils
from flax.training.common_utils import shard
from PIL import Image
from argparse import Namespace
import gradio as gr
import numpy as np
import mediapipe as mp
from mediapipe import solutions
from mediapipe.framework.formats import landmark_pb2
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
import cv2
from diffusers import (
    FlaxControlNetModel,
    FlaxStableDiffusionControlNetPipeline,
)

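# Overview: an uploaded image is run through MediaPipe hand-landmark detection,
# the detected landmarks are rendered onto a black canvas, and that annotation
# is used as the conditioning image for a Flax Stable Diffusion ControlNet pipeline.
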
# mediapipe annotation
MARGIN = 10  # pixels
FONT_SIZE = 1
FONT_THICKNESS = 1
HANDEDNESS_TEXT_COLOR = (88, 205, 54)  # vibrant green
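# Note: the text-annotation constants above are kept for reference but are not
# used below; only the landmark skeleton is drawn.
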
def draw_landmarks_on_image(rgb_image, detection_result):
    hand_landmarks_list = detection_result.hand_landmarks
    handedness_list = detection_result.handedness
    # Draw on a black canvas so the result can be used directly as a
    # ControlNet conditioning image.
    annotated_image = np.zeros_like(rgb_image)

    # Loop through the detected hands to visualize.
    for idx in range(len(hand_landmarks_list)):
        hand_landmarks = hand_landmarks_list[idx]
        handedness = handedness_list[idx]

        # Draw the hand landmarks.
        hand_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
        hand_landmarks_proto.landmark.extend([
            landmark_pb2.NormalizedLandmark(x=landmark.x, y=landmark.y, z=landmark.z)
            for landmark in hand_landmarks
        ])
        solutions.drawing_utils.draw_landmarks(
            annotated_image,
            hand_landmarks_proto,
            solutions.hands.HAND_CONNECTIONS,
            solutions.drawing_styles.get_default_hand_landmarks_style(),
            solutions.drawing_styles.get_default_hand_connections_style())

    return annotated_image

def generate_annotation(img):
    """img (input): numpy array
    annotated_image (output): numpy array
    """
    # STEP 2: Create a HandLandmarker object.
    # Requires the MediaPipe `hand_landmarker.task` model file in the working directory.
    base_options = python.BaseOptions(model_asset_path='hand_landmarker.task')
    options = vision.HandLandmarkerOptions(base_options=base_options,
                                           num_hands=2)
    detector = vision.HandLandmarker.create_from_options(options)

    # STEP 3: Load the input image.
    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)

    # STEP 4: Detect hand landmarks from the input image.
    detection_result = detector.detect(image)

    # STEP 5: Process the detection result. In this case, visualize it.
    annotated_image = draw_landmarks_on_image(image.numpy_view(), detection_result)
    return annotated_image

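# Model configuration: Stable Diffusion v1.5 as the base model (loaded from its
# PyTorch "non-ema" checkpoint) with a ControlNet conditioned on hand-landmark
# annotations.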
args = Namespace(
    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
    revision="non-ema",
    from_pt=True,
    controlnet_model_name_or_path="Vincent-luo/controlnet-hands",
    controlnet_revision=None,
    controlnet_from_pt=False,
)

weight_dtype = jnp.float32

controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    args.controlnet_model_name_or_path,
    revision=args.controlnet_revision,
    from_pt=args.controlnet_from_pt,
    dtype=jnp.float32,
)

pipeline, pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    # tokenizer=tokenizer,
    controlnet=controlnet,
    safety_checker=None,
    dtype=weight_dtype,
    revision=args.revision,
    from_pt=args.from_pt,
)

pipeline_params["controlnet"] = controlnet_params
pipeline_params = jax_utils.replicate(pipeline_params)

rng = jax.random.PRNGKey(0)
num_samples = jax.device_count()
prng_seed = jax.random.split(rng, jax.device_count())

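# One sample is generated per accelerator device: the parameters were replicated
# above, and prompts and conditioning images are sharded so the jitted pipeline
# call below runs in parallel across devices.
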
def infer(prompt, negative_prompt, image):
    prompts = num_samples * [prompt]
    prompt_ids = pipeline.prepare_text_inputs(prompts)
    prompt_ids = shard(prompt_ids)

    # Render the hand-landmark annotation used as the ControlNet conditioning image.
    annotated_image = generate_annotation(image)
    validation_image = Image.fromarray(annotated_image).convert("RGB")
    processed_image = pipeline.prepare_image_inputs(num_samples * [validation_image])
    processed_image = shard(processed_image)

    negative_prompt_ids = pipeline.prepare_text_inputs([negative_prompt] * num_samples)
    negative_prompt_ids = shard(negative_prompt_ids)

    images = pipeline(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=pipeline_params,
        prng_seed=prng_seed,
        num_inference_steps=50,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images

    # Collapse the (devices, per-device batch) leading axes into a flat batch.
    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
    results = [i for i in images]
    return [annotated_image] + results

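# Gradio UI: prompt, negative prompt, and input image feed `infer`; the gallery
# shows the landmark annotation followed by the generated samples.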
with gr.Blocks(theme='gradio/soft') as demo:
    gr.Markdown("## Stable Diffusion with Hand Control")
    gr.Markdown("Upload a picture of hands and write a prompt: the app detects the hand landmarks with MediaPipe and uses them to condition a hand-pose ControlNet on top of Stable Diffusion.")

    with gr.Column():
        prompt_input = gr.Textbox(label="Prompt")
        negative_prompt = gr.Textbox(label="Negative Prompt")
        input_image = gr.Image(label="Input Image")
        output_image = gr.Gallery(label='Output Image', show_label=False, elem_id="gallery").style(grid=3, height='auto')
        submit_btn = gr.Button(value="Submit")

    inputs = [prompt_input, negative_prompt, input_image]
    submit_btn.click(fn=infer, inputs=inputs, outputs=[output_image])

demo.launch()