---
language: en
library_name: tensorflow
tags:
- keras
- tensorflow
- tabular
- iris
- multiclass-classification
pipeline_tag: tabular-classification
license: mit
---

# Iris MLP Classifier (Keras / TensorFlow)

This repository contains a simple **multiclass classifier** for the classic **Iris** dataset, implemented as a small **MLP (Multi-Layer Perceptron)** in **TensorFlow / Keras**.
The model predicts one of three classes based on four numerical features.

## Task

**Tabular multiclass classification**

Given the 4 iris measurements, predict the class:

- `0` → setosa
- `1` → versicolor
- `2` → virginica
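
The index-to-name mapping above follows the order of `target_names` in scikit-learn's copy of the dataset. A minimal sketch to confirm it, assuming scikit-learn is installed (`CLASS_NAMES` is an illustrative name, not something shipped with this repository):

```python
from sklearn.datasets import load_iris

iris = load_iris()

# target_names is indexed by the integer label: position i is the name of class i
CLASS_NAMES = [str(name) for name in iris.target_names]
print(dict(enumerate(CLASS_NAMES)))  # {0: 'setosa', 1: 'versicolor', 2: 'virginica'}
```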

## Dataset

**Iris dataset** (from `sklearn.datasets.load_iris`): 150 samples, 50 per class.

### Input features (4)

The model expects **4 float features** in this exact order:

1. `sepal length (cm)`
2. `sepal width (cm)`
3. `petal length (cm)`
4. `petal width (cm)`
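
Because the model receives a plain `(n, 4)` array, column names are not checked and only the order matters. A small sketch, assuming pandas and scikit-learn, that confirms the order against scikit-learn's `feature_names` and reorders an arbitrary DataFrame before prediction (`FEATURE_ORDER` and `to_model_input` are hypothetical helpers, not part of the repository):

```python
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris

# Expected column order, as listed above
FEATURE_ORDER = [
    "sepal length (cm)",
    "sepal width (cm)",
    "petal length (cm)",
    "petal width (cm)",
]

# scikit-learn uses the same names in the same order
assert list(load_iris().feature_names) == FEATURE_ORDER


def to_model_input(df: pd.DataFrame) -> np.ndarray:
    """Select the 4 feature columns in the expected order as a float32 array."""
    return df[FEATURE_ORDER].to_numpy(dtype=np.float32)


# Columns deliberately given out of order
shuffled = pd.DataFrame([{
    "petal width (cm)": 0.2,
    "sepal length (cm)": 5.1,
    "petal length (cm)": 1.4,
    "sepal width (cm)": 3.5,
}])
x = to_model_input(shuffled)  # columns now follow FEATURE_ORDER
print(x)
```

Selecting columns by name before calling `to_numpy()` protects against silently swapped features, which would otherwise still produce a (wrong) prediction.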

### Target (3 classes)

- Integer labels `y ∈ {0, 1, 2}`
- No one-hot encoding required (training uses `sparse_categorical_crossentropy`)
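
Sparse categorical cross-entropy on integer labels gives the same value as ordinary categorical cross-entropy on one-hot labels, which is why no encoding step is needed. A NumPy sketch of both formulas on made-up predicted probabilities (illustrative values, not model output; Keras additionally guards against `log(0)` internally, which is omitted here):

```python
import numpy as np

# Hypothetical softmax outputs for 3 samples (each row sums to 1)
proba = np.array([
    [0.90, 0.08, 0.02],
    [0.10, 0.70, 0.20],
    [0.05, 0.15, 0.80],
])
y = np.array([0, 1, 2])  # integer labels, as used for training

# Sparse form: -log of the probability assigned to the true class
sparse_ce = -np.log(proba[np.arange(len(y)), y])

# Dense form: one-hot labels, then -sum(one_hot * log(p)) per sample
one_hot = np.eye(3)[y]
dense_ce = -np.sum(one_hot * np.log(proba), axis=1)

print(np.allclose(sparse_ce, dense_ce))  # True
print(sparse_ce.mean())  # average loss over the batch
```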
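
---

## Model architecture

A small feed-forward network with built-in feature normalization:

- `Input(shape=(4,))`
- `Normalization` (feature-wise standardization, adapted on the training split)
- `Dense(16, activation="relu")`
- `Dense(16, activation="relu")`
- `Dense(3, activation="softmax")`, one probability per class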
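
With the layer sizes used in the training script (4 inputs, two `Dense(16)` hidden layers, a `Dense(3)` output), each dense layer has `inputs × units` weights plus `units` biases. The arithmetic below gives the trainable parameter count; the `Normalization` layer's mean and variance are non-trainable and are not included:

```python
# (inputs, units) for each Dense layer, taken from the training script
dense_layers = [(4, 16), (16, 16), (16, 3)]

# weights (inputs * units) + biases (units) per layer
per_layer = [n_in * n_out + n_out for n_in, n_out in dense_layers]
print(per_layer)       # [80, 272, 51]
print(sum(per_layer))  # 403
```

This total should match the trainable parameter count reported by `model.summary()`.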

### Why a `Normalization` layer?

The `tf.keras.layers.Normalization` layer learns the feature-wise mean and variance of the training set via `adapt(...)`.
This makes inference simpler and safer: **the same scaling used during training is embedded in the saved model**, so raw measurements in centimeters can be passed directly.
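
Conceptually, `adapt(...)` stores the per-feature mean and variance of the data it sees, and the layer then computes `(x - mean) / sqrt(variance)` at call time. A NumPy approximation of that behavior, assuming scikit-learn for the data; this is a sketch of the math, not Keras's implementation, and the `EPS` floor on the standard deviation is an assumption mirroring Keras's guard against division by zero:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

EPS = 1e-7  # assumed floor on the standard deviation

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

# "adapt": feature-wise statistics from the training split only
mean = X_train.mean(axis=0)
variance = X_train.var(axis=0)  # population variance (ddof=0)


def normalize(x: np.ndarray) -> np.ndarray:
    """Apply the stored training statistics to any input, e.g. new samples at inference."""
    return (x - mean) / np.maximum(np.sqrt(variance), EPS)


z_train = normalize(X_train)
print(z_train.mean(axis=0).round(6))  # ~0 per feature
print(z_train.std(axis=0).round(6))   # ~1 per feature
```

Because `mean` and `variance` are frozen after `adapt`, test or production samples are scaled with the training statistics rather than their own.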

---

## Training configuration

- **Optimizer:** Adam (`learning_rate=1e-3`)
- **Loss:** `sparse_categorical_crossentropy`
- **Metric:** accuracy
- **Train/test split:** 80/20 (`stratify=y`, `random_state=42`)
- **Validation split (from train):** 20% (`validation_split=0.2`)
- **Epochs:** 100
- **Batch size:** 16
- **Reproducibility:**
  - `tf.random.set_seed(42)`
  - `np.random.seed(42)`

> Note: Exact accuracy may vary slightly across environments due to numerical differences and nondeterminism in some TF ops.
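
The split settings above imply concrete sample counts for the 150-sample dataset. A sketch that reproduces the train/test split with scikit-learn and estimates the validation hold-out; it relies on Keras's documented `validation_split` behavior (the last fraction of the training arrays, taken before shuffling), and the exact rounding Keras applies is an assumption here, so the validation size could differ by one sample:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)

# 80/20 stratified split, as in the training configuration
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)
print(len(X_train), len(X_test))  # 120 30
print(np.bincount(y_test))        # [10 10 10]

# validation_split=0.2: the *last* 20% of the training arrays are held out
n_val = round(0.2 * len(y_train))
y_fit, y_val = y_train[:-n_val], y_train[-n_val:]
print(len(y_fit), len(y_val))     # 96 24

# Unlike the test split, this slice is not stratified
print(np.bincount(y_val, minlength=3))
```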
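
---

## Example: training script (reference)

The model was trained using the following core logic (simplified):

```python
import numpy as np
import tensorflow as tf
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Reproducibility seeds (see "Training configuration")
tf.random.set_seed(42)
np.random.seed(42)

# Data loading and split follow "Training configuration": X = 4 feature columns, y = labels 0..2
X, y = load_iris(return_X_y=True, as_frame=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

# Learn feature-wise mean/variance from the training split only
normalizer = tf.keras.layers.Normalization(axis=-1)
normalizer.adapt(X_train.to_numpy())

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    normalizer,
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(3, activation="softmax"),  # one probability per class
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss="sparse_categorical_crossentropy",  # integer labels, no one-hot needed
    metrics=["accuracy"],
)

model.fit(
    X_train.to_numpy(), y_train.to_numpy(),
    validation_split=0.2,  # Keras holds out the last 20% of the training arrays
    epochs=100,
    batch_size=16,
    verbose=0,
)
```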

## Example: inference

Download the saved model from the Hub and predict the class of a single flower:

```python
import numpy as np
import pandas as pd
import tensorflow as tf
from huggingface_hub import hf_hub_download

# Download the saved model file from the Hub (cached locally after the first call)
filename = "iris_mlp.keras"
repo_id = "studentscolab/iris_keras"
model_path = hf_hub_download(repo_id=repo_id, filename=filename)
model = tf.keras.models.load_model(model_path)

# One sample with raw measurements in cm, columns in the order listed under "Input features"
x_new = pd.DataFrame([{
    "sepal length (cm)": 5.1,
    "sepal width (cm)": 3.5,
    "petal length (cm)": 1.4,
    "petal width (cm)": 0.2,
}])

# No manual scaling: the model's Normalization layer applies the training statistics
proba = model.predict(x_new.to_numpy(), verbose=0)[0]  # shape: (3,)
pred = int(np.argmax(proba))

class_names = ["setosa", "versicolor", "virginica"]  # index = class label
print("Probabilities:", proba)
print("Predicted class:", pred, f"({class_names[pred]})")
```
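
`model.predict` also accepts several rows at once and returns one probability row per sample. A sketch of turning such a batch into class names, using a made-up probability array in place of real model output (the values are illustrative, and `CLASS_NAMES` follows the label mapping under "Task"):

```python
import numpy as np

CLASS_NAMES = ["setosa", "versicolor", "virginica"]  # index = integer label

# Stand-in for model.predict(batch, verbose=0) on 3 samples; values are illustrative
proba = np.array([
    [0.97, 0.02, 0.01],
    [0.05, 0.80, 0.15],
    [0.01, 0.30, 0.69],
], dtype=np.float32)

pred = proba.argmax(axis=1)      # most probable class per row
confidence = proba.max(axis=1)   # probability of that class
labels = [CLASS_NAMES[i] for i in pred]
print(labels)  # ['setosa', 'versicolor', 'virginica']
```

Keeping `confidence` alongside the label makes it easy to flag low-certainty predictions, such as flowers near the versicolor/virginica boundary.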