Text Generation
Transformers
PyTorch
TensorFlow
JAX
TensorBoard
English
t5
text2text-generation
seq2seq
recipe-generation
text-generation-inference
Instructions to use jejun/flax-recipe-generator with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use jejun/flax-recipe-generator with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text2text-generation", model="jejun/flax-recipe-generator")
```

```python
# Load model directly
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("jejun/flax-recipe-generator")
model = AutoModelForSeq2SeqLM.from_pretrained("jejun/flax-recipe-generator")
```
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use jejun/flax-recipe-generator with vLLM:
Install from pip and serve model
```bash
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "jejun/flax-recipe-generator"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "jejun/flax-recipe-generator",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
```
Use Docker
docker model run hf.co/jejun/flax-recipe-generator
- SGLang
How to use jejun/flax-recipe-generator with SGLang:
Install from pip and serve model
```bash
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "jejun/flax-recipe-generator" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "jejun/flax-recipe-generator",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
```
Use Docker images
```bash
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "jejun/flax-recipe-generator" \
        --host 0.0.0.0 \
        --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "jejun/flax-recipe-generator",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
```
- Docker Model Runner
How to use jejun/flax-recipe-generator with Docker Model Runner:
docker model run hf.co/jejun/flax-recipe-generator
import logging
import os
import random
import re
import sys
import time
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import Callable, Optional

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
# NOTE(review): load_metric is unused here and was removed in newer `datasets` releases;
# pin an older `datasets` or drop this name if the import fails.
from datasets import Dataset, load_dataset, load_metric
from filelock import FileLock
from flax import jax_utils, traverse_util
from flax.jax_utils import unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers import FlaxAutoModelForSeq2SeqLM
# Report the devices JAX can see before any model work starts.
print(jax.devices())

# Directory holding both the tokenizer files and the Flax seq2seq checkpoint.
# It is a relative path, so it is resolved against the current working directory.
MODEL_NAME_OR_PATH = "../"
tokenizer, model = (
    AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True),
    FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH),
)
# Preprocessing / batching configuration.
prefix = "items: "  # prepended to every input text before tokenization
text_column, target_column = "inputs", "targets"  # CSV column names
max_source_length, max_target_length = 256, 1024  # tokenizer truncation/padding lengths
seed = 42
eval_batch_size = 64

# Beam-search settings forwarded to model.generate.
# Sampling alternative that was left disabled: max_length=1024, min_length=128,
# no_repeat_ngram_size=3, do_sample=True, top_k=60, top_p=0.95.
generation_kwargs = dict(
    max_length=1024,
    min_length=64,
    no_repeat_ngram_size=3,
    early_stopping=True,
    num_beams=4,
    length_penalty=1.5,
)

# All special tokens of the tokenizer; these are stripped from decoded recipes.
special_tokens = tokenizer.all_special_tokens
# Recipe markup tokens and the plain-text strings they are rendered as.
tokens_map = {
    "<sep>": "--",
    "<section>": "\n",
}
def skip_special_tokens(text, special_tokens):
    """Strip every occurrence of each token in ``special_tokens`` from ``text``.

    Tokens are removed one at a time, in iteration order. Text exposed by
    removing an earlier token can therefore still be matched by a later one.

    Args:
        text: decoded string to clean.
        special_tokens: iterable of literal substrings to delete.

    Returns:
        The cleaned string.
    """
    cleaned = text
    for token_str in special_tokens:
        cleaned = cleaned.replace(token_str, '')
    return cleaned
def target_postprocessing(texts, special_tokens):
    """Turn decoded recipe strings into plain text.

    For each string, every token in ``special_tokens`` is removed first (via
    ``skip_special_tokens``). Then each key of the module-level ``tokens_map``
    is replaced by its mapped value.

    Args:
        texts: a list of strings, or a single string. Anything that is not a
            ``list`` is wrapped into a one-element list.
        special_tokens: substrings deleted before the markup substitution.

    Returns:
        A new list containing one cleaned string per input.
    """
    batch = texts if isinstance(texts, list) else [texts]

    def render(raw):
        cleaned = skip_special_tokens(raw, special_tokens)
        for marker, replacement in tokens_map.items():
            cleaned = cleaned.replace(marker, replacement)
        return cleaned

    return [render(raw) for raw in batch]
# Tab-separated test split. The default is the original author's local file;
# set the RECIPE_TEST_FILE environment variable to use another copy without
# editing the script.
PREDICT_FILE = os.environ.get("RECIPE_TEST_FILE", "/home/m3hrdadfi/code/data/test.csv")
predict_dataset = load_dataset("csv", data_files={"test": PREDICT_FILE}, delimiter="\t")["test"]
print(predict_dataset)
# For a quick smoke test, restrict prediction to the first rows:
# predict_dataset = predict_dataset.select(range(10))
# print(predict_dataset)
# Raw column names; they are removed after tokenization.
column_names = predict_dataset.column_names
print(column_names)
# Inputs and targets are padded to fixed lengths (padding="max_length") because
# the jitted/pmapped generation step needs identically shaped batches.
def preprocess_function(examples):
    """Tokenize one batch of raw rows for prediction.

    Args:
        examples: batched mapping as passed by ``Dataset.map(batched=True)``.
            It must contain the ``text_column`` and ``target_column`` entries,
            presumably lists of strings (CSV schema not visible here; verify).

    Returns:
        The tokenizer encoding (numpy arrays) of ``prefix`` + each input,
        padded/truncated to ``max_source_length``. It has an extra ``"labels"``
        entry holding the target ``input_ids``, padded/truncated to
        ``max_target_length``.
    """
    def encode(texts, length):
        return tokenizer(
            texts,
            max_length=length,
            padding="max_length",
            truncation=True,
            return_tensors="np",
        )

    raw_inputs, targets = examples[text_column], examples[target_column]
    sources = [prefix + text for text in raw_inputs]

    model_inputs = encode(sources, max_source_length)
    # Targets are encoded inside the tokenizer's target-side context.
    with tokenizer.as_target_tokenizer():
        label_ids = encode(targets, max_target_length)["input_ids"]
    model_inputs["labels"] = label_ids
    return model_inputs
# Tokenize the prediction split in batches and drop the raw columns. Only the
# tokenizer outputs (e.g. input_ids, attention_mask) plus "labels" remain.
map_kwargs = dict(
    batched=True,
    num_proc=None,
    remove_columns=column_names,
    desc="Running tokenizer on prediction dataset",
)
predict_dataset = predict_dataset.map(preprocess_function, **map_kwargs)
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
    """Yield device-sharded batches of exactly ``batch_size`` examples from ``dataset``.

    Only ``len(dataset) // batch_size`` full batches are produced. The trailing
    ``len(dataset) % batch_size`` examples are skipped, so callers that need
    every example must handle that remainder themselves.

    Args:
        rng: PRNG key used to permute the example order when ``shuffle`` is True.
            It is ignored otherwise.
        dataset: indexable dataset; ``dataset[indices]`` must return a mapping
            of column name -> per-example values.
        batch_size: global batch size. It must be a positive multiple of
            ``jax.local_device_count()`` so ``shard`` can split it evenly.
        shuffle: visit examples in an ``rng``-determined random order.

    Yields:
        dict of numpy arrays, each reshaped by ``shard`` to a leading
        (local device, per-device batch) layout ready for ``jax.pmap``.

    Raises:
        ValueError: if ``batch_size`` is not a positive multiple of the local
            device count. This is raised on the first ``next()``, since this is
            a generator.
    """
    num_devices = jax.local_device_count()
    if batch_size <= 0 or batch_size % num_devices:
        raise ValueError(
            f"batch_size ({batch_size}) must be a positive multiple of the "
            f"local device count ({num_devices})"
        )

    steps_per_epoch = len(dataset) // batch_size

    # Keep the index schedule and batches host-side (numpy). Building them with
    # jnp would first place every batch on the default device; pmap distributes
    # the shards itself. The shuffled order is the same jax permutation as before.
    if shuffle:
        batch_idx = np.asarray(jax.random.permutation(rng, len(dataset)))
    else:
        batch_idx = np.arange(len(dataset))

    batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))

    for idx in batch_idx:
        batch = dataset[idx]
        batch = {k: np.array(v) for k, v in batch.items()}
        yield shard(batch)
# PRNG keys derived from `seed`. The split sequence is unchanged, so the derived
# keys are identical to before. In the visible code, dropout_rng is never used,
# and input_rng only matters if data_loader is asked to shuffle.
rng, dropout_rng = jax.random.split(jax.random.PRNGKey(seed))
rng, input_rng = jax.random.split(rng)
def generate_step(batch):
    """Run ``model.generate`` on one per-device slice of a sharded batch.

    Uses the module-level ``generation_kwargs`` and returns only the generated
    token ids (the ``.sequences`` field of the generate output).
    """
    return model.generate(
        batch["input_ids"],
        attention_mask=batch["attention_mask"],
        **generation_kwargs,
    ).sequences


# Parallel version over all local devices; inputs must carry a leading device axis.
p_generate_step = jax.pmap(generate_step, axis_name="batch")
# Host-side rows (one per prediction example, in dataset order).
pred_generations = []
pred_labels = []
pred_inputs = []


def _to_rows(x):
    """Merge the leading (device, per-device batch) axes into one example axis on the host."""
    return jax.device_get(x.reshape(-1, x.shape[-1]))


pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size)
pred_steps = len(predict_dataset) // eval_batch_size
for _ in tqdm(range(pred_steps), desc="Predicting...", position=2, leave=False):
    # Model forward
    batch = next(pred_loader)
    generated_ids = p_generate_step(batch)
    pred_generations.extend(_to_rows(generated_ids))
    pred_labels.extend(_to_rows(batch["labels"]))
    pred_inputs.extend(_to_rows(batch["input_ids"]))

# data_loader only yields full batches, so the last
# len(predict_dataset) % eval_batch_size examples would be missing from the output.
# Generate them in one extra batch, padded up to eval_batch_size by repeating the
# final example. The padded batch then has the same shape as the others (so the
# pmapped step is reused and shard() splits it evenly). The padding rows are
# discarded afterwards.
num_remaining = len(predict_dataset) - pred_steps * eval_batch_size
if num_remaining:
    tail_idx = list(range(pred_steps * eval_batch_size, len(predict_dataset)))
    tail_idx += [tail_idx[-1]] * (eval_batch_size - num_remaining)
    tail_batch = shard({k: jnp.array(v) for k, v in predict_dataset[tail_idx].items()})
    generated_ids = p_generate_step(tail_batch)
    pred_generations.extend(_to_rows(generated_ids)[:num_remaining])
    pred_labels.extend(_to_rows(tail_batch["labels"])[:num_remaining])
    pred_inputs.extend(_to_rows(tail_batch["input_ids"])[:num_remaining])
def _decode_recipes(token_id_rows):
    """Decode token-id rows with special tokens kept, then clean them via target_postprocessing."""
    return target_postprocessing(
        tokenizer.batch_decode(token_id_rows, skip_special_tokens=False),
        special_tokens,
    )


# Model inputs (prefix + source text); here the tokenizer itself drops special tokens.
inputs = tokenizer.batch_decode(pred_inputs, skip_special_tokens=True)
# Reference and generated recipes: decoded with skip_special_tokens=False, then
# target_postprocessing strips special tokens and renders the tokens_map markup.
true_recipe = _decode_recipes(pred_labels)
generated_recipe = _decode_recipes(pred_generations)

# One row per example, written as a tab-separated file.
test_output = pd.DataFrame(
    {
        "inputs": inputs,
        "true_recipe": true_recipe,
        "generated_recipe": generated_recipe,
    }
)
test_output.to_csv("./generated_recipes_b.csv", sep="\t", index=False, encoding="utf-8")