| | |
| | import argparse |
| | from huggingface_hub import HfApi |
| |
|
| |
|
def main(api, model_id, *, exclude=("main",)):
    """Return the branch names of a Hub repo, minus the excluded ones.

    Args:
        api: Client exposing ``list_repo_refs(model_id)``, whose result has a
            ``branches`` iterable of objects with a ``.name`` attribute
            (e.g. ``huggingface_hub.HfApi``).
        model_id: Repo id on the Hub, e.g. ``"gpt2"`` or
            ``"facebook/wav2vec2-base-960h"``.
        exclude: Branch name(s) to leave out of the result. A single string
            is treated as one name. Defaults to ``("main",)``.

    Returns:
        A sorted list of unique branch names. The list is sorted so the
        output is deterministic; set iteration order is not.
    """
    # Guard against exclude="main" being split into {"m", "a", "i", "n"}.
    if isinstance(exclude, str):
        exclude = (exclude,)
    refs = api.list_repo_refs(model_id)
    names = {branch.name for branch in refs.branches}
    return sorted(names - set(exclude))
| |
|
| |
|
if __name__ == "__main__":
    DESCRIPTION = """
    Simple utility that prints the model id if its repo on the Hub has a `non-ema` branch
    """
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.add_argument(
        "--model_id",
        type=str,
        # Without this, a missing flag yields model_id=None and a confusing
        # failure inside the Hub client instead of a clear usage error.
        required=True,
        help="The name of the model on the hub to retrieve the branches from. E.g. `gpt2` or `facebook/wav2vec2-base-960h`",
    )

    args = parser.parse_args()
    model_id = args.model_id
    api = HfApi()
    branches = main(api, model_id)

    # Only the id is printed, so running this over many models produces a
    # plain list of the ones that have a `non-ema` branch.
    if "non-ema" in branches:
        print(model_id)
| | |
| | |
| | |
| |
|