jeetv commited on
Commit
fdf57c0
·
verified ·
1 Parent(s): 3690479

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +26 -15
README.md CHANGED
@@ -62,35 +62,46 @@ import opensportslib
62
  print("OpenSportsLib imported successfully")
63
  ```
64
 
65
- ### Train a classification model
66
 
67
  ```python
68
- from opensportslib import model
69
 
70
- myModel = model.localization(
71
- config="/path/to/localization.yaml"
 
72
  )
73
 
74
- myModel.train(
75
- train_set="/path/to/train_annotations.json",
76
- valid_set="/path/to/valid_annotations.json",
77
- 👉 pretrained="OpenSportsLab/oslib-e2e-localization-snbas-2023", # optional
78
  )
 
 
79
  ```
80
 
81
  ### Run inference
82
 
83
  ```python
84
- from opensportslib import model
 
 
 
 
 
 
 
 
 
85
 
86
- myModel = model.localization(
87
- config="/path/to/localization.yaml"
 
88
  )
89
 
90
- metrics = myModel.infer(
91
- test_set="/path/to/test_annotations.json",
92
- 👉 pretrained="OpenSportsLab/oslib-e2e-localization-snbas-2023",
93
- predictions="/path/to/predictions.json"
94
  )
95
 
96
  print(metrics)
 
62
  print("OpenSportsLib imported successfully")
63
  ```
64
 
65
+ ### Train a Localization model
66
 
67
  ```python
68
+ from opensportslib.apis import LocalizationModel
69
 
70
+ my_model = LocalizationModel(
71
+ config="/path/to/localization.yaml",
72
+ 👉 weights="OpenSportsLab/OSL-loc-snbas-2023-e2e", # optional
73
  )
74
 
75
+ best_checkpoint = my_model.train(
76
+ train_set="/path/to/train.json",
77
+ valid_set="/path/to/valid.json",
 
78
  )
79
+
80
+ print(best_checkpoint)
81
  ```
82
 
83
  ### Run inference
84
 
85
  ```python
86
+ from opensportslib.apis import LocalizationModel
87
+
88
+ my_model = LocalizationModel(
89
+ config="/path/to/localization.yaml",
90
+ weights="OpenSportsLab/OSL-loc-snbas-2023-e2e",
91
+ )
92
+
93
+ predictions = my_model.infer(
94
+ test_set="/path/to/test.json",
95
+ )
96
 
97
+ saved_predictions = my_model.save_predictions(
98
+ output_path="/path/to/predictions.json",
99
+ predictions=predictions,
100
  )
101
 
102
+ metrics = my_model.evaluate(
103
+ test_set="/path/to/test.json",
104
+ predictions=saved_predictions,
 
105
  )
106
 
107
  print(metrics)