Fast Neural Style Transfer: LiteRT (on-device, fully-GPU, 4 styles)

Fast neural style transfer (the TransformerNet from PyTorch examples, after Johnson et al.), converted to LiteRT and running fully on the CompiledModel GPU (ML Drift) on Android. It applies an artistic style to a photo. Four styles are provided (candy / mosaic / rain_princess / udnie), each a 3.5 MB fp16 graph.

[Figure: Fast Neural Style, content + candy / mosaic / rain / udnie (on-device LiteRT GPU)]

On-device (Pixel 8a, Tensor G3, verified)

nodes on GPU: 350 / 350 LITERT_CL (full residency)
inference: ~9 ms (256×256)
size: 3.5 MB per style (fp16)
accuracy: device-vs-PyTorch corr 0.9998–0.9999 (all 4 styles)

image[1,3,256,256] (RGB 0-255) → [GPU: TransformerNet] → stylized[1,3,256,256] (RGB 0-255)

How it converts (litert-torch): three numerically-exact re-authorings

  1. ReflectionPad2d → zero-pad (GATHER_ND → PAD; the output differs only at the image border).
  2. Large conv activations → conv-weight scaling. The conv outputs reach ≈ |5000|, where the Mali delegate's fp16 conv accumulation loses precision and produces garbage (device corr 0.34 at full residency: residency ≠ correctness). Each conv is followed by an InstanceNorm, which is scale-invariant, so scaling those conv weights down until the output is ≈ |10| is exact (the IN output is unchanged) and keeps the fp16 accumulation precise → corr 1.0.
  3. InstanceNorm → SafeInstanceNorm (down-scaled-domain spatial reduction, fp16-safe; SafeLayerNorm class).
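
The reasoning in steps 2 and 3 can be checked numerically. The sketch below is pure Python and is not the converter's code: `instance_norm`, `to_fp16`, and the sample values are hypothetical, and fp16 rounding is emulated with the half-precision format of the `struct` module rather than on a GPU. A convolution is linear, so scaling its weight and bias by `s` scales its output by `s`; the sketch therefore scales the conv output directly. It shows that the InstanceNorm result is unchanged up to the effect of the `eps` term, and that squaring a value near 5000 overflows fp16 while squaring a value near 10 does not, which is one plausible reason (an assumption here) for doing the reduction in a down-scaled domain.

```python
import math
import struct


def to_fp16(x):
    """Round a float to the nearest IEEE 754 half-precision value (inf on overflow)."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except (OverflowError, struct.error):
        return math.copysign(math.inf, x)


def instance_norm(channel, eps=1e-5):
    """Normalize one channel of one image to zero mean and unit variance."""
    n = len(channel)
    mean = sum(channel) / n
    var = sum((v - mean) ** 2 for v in channel) / n
    return [(v - mean) / math.sqrt(var + eps) for v in channel]


# Hypothetical conv output for one channel, with magnitudes up to about 5000.
conv_out = [4800.0, -5000.0, 1200.0, -300.0, 2500.0, -4100.0]

# Scale so the largest magnitude is about 10 (same effect as scaling weight and bias).
scale = 10.0 / max(abs(v) for v in conv_out)
scaled_out = [v * scale for v in conv_out]

reference = instance_norm(conv_out)
rescaled = instance_norm(scaled_out)
max_diff = max(abs(a - b) for a, b in zip(reference, rescaled))

# Squares are what a variance reduction accumulates; fp16 tops out at 65504.
sq_large = to_fp16(5000.0 ** 2)  # overflows in fp16
sq_small = to_fp16(10.0 ** 2)    # representable in fp16

print("max InstanceNorm difference after scaling:", max_diff)
print("fp16(5000^2) =", sq_large, " fp16(10^2) =", sq_small)
```

The invariance is exact only when `eps` is zero; with a nonzero `eps` the scaled variance must stay large relative to `eps`, which holds for outputs around |10|.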

Upsample is interpolate(nearest) (no transposed conv → no ZeroStuff). Result: banned ops NONE, ≤4D, tflite-vs-torch corr 1.0, device-vs-torch corr 0.9999.
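
The corr figures can be reproduced with a comparison along these lines. This is a minimal sketch that assumes corr means the Pearson correlation between the flattened reference and device outputs; the card does not define the metric, and `pearson_corr` and the sample values are hypothetical.

```python
import math


def pearson_corr(a, b):
    """Pearson correlation of two equal-length sequences of floats."""
    n = len(a)
    mean_a = sum(a) / n
    mean_b = sum(b) / n
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    return cov / math.sqrt(var_a * var_b)


# Hypothetical flattened outputs: a reference and a slightly perturbed copy.
reference_out = [12.0, 200.0, 87.0, 255.0, 0.0, 143.0]
noise = [0.5, -1.0, 0.25, -0.5, 1.0, -0.25]
device_out = [v + d for v, d in zip(reference_out, noise)]

corr = pearson_corr(reference_out, device_out)
print("corr:", corr)
```

A correlation near 1.0 says the two outputs vary together; it does not by itself rule out a constant offset or gain error, so a max-absolute-difference check is a useful companion.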

Preprocessing

Center-crop to a square, resize to 256×256, and feed RGB 0–255 (no normalization) in NCHW layout. The output is RGB 0–255 (clamp to that range).
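
The preprocessing steps can be sketched as follows. This is an illustration in pure Python on nested lists, not the app's code: all function names are hypothetical, and nearest-neighbor resizing is an assumption because the card does not say which resize filter is used.

```python
def center_crop_box(width, height):
    """Return (left, top, right, bottom) of the largest centered square."""
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    return left, top, left + side, top + side


def resize_nearest(pixels, size):
    """Resize a square HxWx3 nested list to size x size (nearest neighbor)."""
    src = len(pixels)
    return [
        [pixels[y * src // size][x * src // size] for x in range(size)]
        for y in range(size)
    ]


def to_nchw(pixels):
    """Convert HxWx3 to a [1, 3, H, W] nested list of floats, values left as 0-255."""
    h, w = len(pixels), len(pixels[0])
    return [[[[float(pixels[y][x][c]) for x in range(w)] for y in range(h)]
             for c in range(3)]]


def preprocess(image, size=256):
    """Center-crop, resize, and lay out an HxWx3 RGB image for the model."""
    h, w = len(image), len(image[0])
    left, top, right, bottom = center_crop_box(w, h)
    cropped = [row[left:right] for row in image[top:bottom]]
    return to_nchw(resize_nearest(cropped, size))


def clamp_pixel(value):
    """Clamp one model output value to the 0-255 range."""
    return min(255.0, max(0.0, value))


# Tiny 4x6 test image whose red channel holds x and green channel holds y.
image = [[(x, y, 0) for x in range(6)] for y in range(4)]
tensor = preprocess(image, size=2)
```

Usage: call `preprocess(image)` with the default `size=256` to get the `[1, 3, 256, 256]` input, then apply `clamp_pixel` to each output value before converting back to 8-bit RGB.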

License

BSD-3-Clause. Upstream: pytorch/examples.
