Update README.md
Browse files
README.md
CHANGED
|
@@ -62,35 +62,46 @@ import opensportslib
|
|
| 62 |
print("OpenSportsLib imported successfully")
|
| 63 |
```
|
| 64 |
|
| 65 |
-
### Train a
|
| 66 |
|
| 67 |
```python
|
| 68 |
-
from opensportslib import
|
| 69 |
|
| 70 |
-
|
| 71 |
-
config="/path/to/localization.yaml"
|
|
|
|
| 72 |
)
|
| 73 |
|
| 74 |
-
|
| 75 |
-
train_set="/path/to/
|
| 76 |
-
valid_set="/path/to/
|
| 77 |
-
👉 pretrained="OpenSportsLab/oslib-e2e-localization-snbas-2023", # optional
|
| 78 |
)
|
|
|
|
|
|
|
| 79 |
```
|
| 80 |
|
| 81 |
### Run inference
|
| 82 |
|
| 83 |
```python
|
| 84 |
-
from opensportslib import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
|
| 86 |
-
|
| 87 |
-
|
|
|
|
| 88 |
)
|
| 89 |
|
| 90 |
-
metrics =
|
| 91 |
-
test_set="/path/to/
|
| 92 |
-
|
| 93 |
-
predictions="/path/to/predictions.json"
|
| 94 |
)
|
| 95 |
|
| 96 |
print(metrics)
|
|
|
|
| 62 |
print("OpenSportsLib imported successfully")
|
| 63 |
```
|
| 64 |
|
| 65 |
+
### Train a Localization model
|
| 66 |
|
| 67 |
```python
|
| 68 |
+
from opensportslib.apis import LocalizationModel
|
| 69 |
|
| 70 |
+
my_model = LocalizationModel(
|
| 71 |
+
config="/path/to/localization.yaml",
|
| 72 |
+
👉 weights="OpenSportsLab/OSL-loc-snbas-2023-e2e", # optional
|
| 73 |
)
|
| 74 |
|
| 75 |
+
best_checkpoint = my_model.train(
|
| 76 |
+
train_set="/path/to/train.json",
|
| 77 |
+
valid_set="/path/to/valid.json",
|
|
|
|
| 78 |
)
|
| 79 |
+
|
| 80 |
+
print(best_checkpoint)
|
| 81 |
```
|
| 82 |
|
| 83 |
### Run inference
|
| 84 |
|
| 85 |
```python
|
| 86 |
+
from opensportslib.apis import LocalizationModel
|
| 87 |
+
|
| 88 |
+
my_model = LocalizationModel(
|
| 89 |
+
config="/path/to/localization.yaml",
|
| 90 |
+
weights="OpenSportsLab/OSL-loc-snbas-2023-e2e",
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
predictions = my_model.infer(
|
| 94 |
+
test_set="/path/to/test.json",
|
| 95 |
+
)
|
| 96 |
|
| 97 |
+
saved_predictions = my_model.save_predictions(
|
| 98 |
+
output_path="/path/to/predictions.json",
|
| 99 |
+
predictions=predictions,
|
| 100 |
)
|
| 101 |
|
| 102 |
+
metrics = my_model.evaluate(
|
| 103 |
+
test_set="/path/to/test.json",
|
| 104 |
+
predictions=saved_predictions,
|
|
|
|
| 105 |
)
|
| 106 |
|
| 107 |
print(metrics)
|