FlameF0X/NanoSR
Viewer • Updated • 1.6k • 1.46k • 1
Introducing NanoSR, a very small 6x image upscaler built from a compact convolutional neural network followed by a PixelShuffle (depth-to-space) layer.
import os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from huggingface_hub import hf_hub_download
from tensorflow.keras import layers
# --- Configuration ---
# Spatial upscaling factor handed to build_nanosr() when the model is built.
UPSCALE_FACTOR = 6
# Hugging Face Hub repository the weights are downloaded from.
HF_REPO_ID = "FlameF0X/NanoSR-6x"
# Weights file (.h5) fetched from that repository and passed to load_weights().
MODEL_FILENAME = "NanoSR-6x_v1.h5"
def build_nanosr(upscale_factor, channels=3):
    """Build the NanoSR super-resolution network.

    The model is fully convolutional: three ReLU convolutions (64, 64 and 32
    filters) run at the input resolution, a final convolution produces
    ``channels * upscale_factor ** 2`` feature maps, and a depth-to-space
    ("pixel shuffle") layer rearranges those maps into an image that is
    ``upscale_factor`` times larger along height and width.

    Args:
        upscale_factor: Integer spatial upscaling factor (the block size
            given to ``tf.nn.depth_to_space``).
        channels: Number of channels of both the input and the output image.

    Returns:
        An uncompiled ``tf.keras.Model`` whose input has shape
        ``(None, None, channels)`` per sample, i.e. any height and width.
    """
    # Settings shared by every convolution in the network.
    conv_kwargs = {"padding": "same", "kernel_initializer": "he_normal"}

    inputs = layers.Input(shape=(None, None, channels))
    x = layers.Conv2D(64, 5, activation="relu", **conv_kwargs)(inputs)
    x = layers.Conv2D(64, 3, activation="relu", **conv_kwargs)(x)
    x = layers.Conv2D(32, 3, activation="relu", **conv_kwargs)(x)
    # No activation here: these maps are the raw sub-pixel values that the
    # pixel shuffle below interleaves into the high-resolution output.
    x = layers.Conv2D(channels * (upscale_factor ** 2), 3, **conv_kwargs)(x)
    outputs = layers.Lambda(
        lambda t: tf.nn.depth_to_space(t, upscale_factor),
        name="pixel_shuffle",
    )(x)

    # Name the model after the factor actually used instead of always "6x";
    # for upscale_factor=6 this is still "NanoSR-6x".
    return tf.keras.Model(inputs, outputs, name=f"NanoSR-{upscale_factor}x")
def run_inference(image_path, output_path="upscaled_result.png", *, show=True):
    """Upscale one image with the pretrained NanoSR weights and save the result.

    Downloads the weights file from the Hugging Face Hub, builds the model,
    runs it on the image at ``image_path``, writes the upscaled image to
    ``output_path`` and, when ``show`` is true, displays the input and the
    output side by side with matplotlib.

    Args:
        image_path: Path of the low-resolution input image. It is converted
            to RGB before being fed to the model.
        output_path: Where the upscaled image is written.
        show: When true (the default), open a matplotlib window comparing
            the input with the result. Pass ``False`` for headless use.
    """
    # 1. Download Weights
    print(f"Downloading weights from {HF_REPO_ID}...")
    checkpoint_path = hf_hub_download(repo_id=HF_REPO_ID, filename=MODEL_FILENAME)

    # 2. Build and Load Model
    model = build_nanosr(upscale_factor=UPSCALE_FACTOR)
    model.load_weights(checkpoint_path)
    print("Model loaded successfully.")

    # 3. Load and Preprocess Image
    # The context manager closes the underlying file handle; convert() has
    # already read the pixel data into a new image by the time we leave it.
    with Image.open(image_path) as img:
        low_res = np.asarray(img.convert("RGB"), dtype=np.float32) / 255.0
    input_tensor = tf.expand_dims(low_res, 0)

    # 4. Predict
    print("Upscaling...")
    prediction = model.predict(input_tensor)[0]
    # Clip with NumPy so the result stays a plain ndarray for PIL/matplotlib.
    prediction = np.clip(prediction, 0.0, 1.0)

    # 5. Save and Show
    # Convert to 8-bit explicitly. array_to_img() with its default scale=True
    # re-normalises by the image's own maximum, which would brighten any
    # output whose peak value is below 1.0.
    final_img = Image.fromarray(np.round(prediction * 255.0).astype(np.uint8))
    final_img.save(output_path)

    if show:
        plt.figure(figsize=(12, 6))
        plt.subplot(1, 2, 1)
        plt.title("Original (Low Res)")
        plt.imshow(low_res)
        plt.axis("off")
        plt.subplot(1, 2, 2)
        plt.title(f"NanoSR {UPSCALE_FACTOR}x")
        plt.imshow(prediction)
        plt.axis("off")
        plt.show()

    print(f"Result saved to {output_path}")
if __name__ == "__main__":
    import sys

    # Use the image path given on the command line when there is one;
    # otherwise fall back to the placeholder file name.
    run_inference(sys.argv[1] if len(sys.argv) > 1 else "your_test_image.png")