meg-huggingface committed
Commit 3d16b0d · Parent(s): 7d70d90
Inference endpoint figuring
src/backend/inference_endpoint.py
CHANGED
@@ -6,83 +6,102 @@ from huggingface_hub import create_inference_endpoint, get_inference_endpoint
 from src.backend.run_toxicity_eval import get_generation
 from src.logging import setup_logger
 import requests
+
 logging.basicConfig(level=logging.DEBUG)
 logger = setup_logger(__name__)
-TIMEOUT=20
+TIMEOUT = 20
+
 
+def create_endpoint(endpoint_name, repository, framework='pytorch',
+                    task='text-generation', accelerator='gpu', vendor='aws',
+                    region='us-east-1', type='protected', instance_size='x4',
+                    instance_type='nvidia-l4'):
     logger.info("Creating endpoint %s..." % endpoint_name)
     # TODO(mm): Handle situation where it's paused
     try:
-        endpoint = create_inference_endpoint(endpoint_name,
+        endpoint = create_inference_endpoint(endpoint_name,
+                                             repository=repository,
+                                             framework=framework, task=task,
+                                             accelerator=accelerator,
+                                             vendor=vendor, region=region,
+                                             type=type,
+                                             instance_size=instance_size,
+                                             instance_type=instance_type)
     except huggingface_hub.utils._errors.HfHubHTTPError as e:
+        # Workload with the same name already exists error.
+        # Use it again, just make sure it has the right settings.
+        logger.debug("Hit error:")
+        logger.debug(e)
+        logger.debug("Attempting to update with the given parameters.")
+        endpoint = get_inference_endpoint(endpoint_name)
+        endpoint.update(repository=repository,
+                        framework=framework, task=task,
+                        accelerator=accelerator,
+                        vendor=vendor, region=region,
+                        type=type,
+                        instance_size=instance_size,
+                        instance_type=instance_type)
     except requests.exceptions.HTTPError as e:
+        # Not enough compute, or wrong compute
+        logger.debug("Hit error:")
+        logger.debug(e)
+        endpoint = update_endpoint_exception(endpoint)
     except Exception as e:
-        logger.debug("Hit error")
+        logger.debug("Hit unaccounted-for error")
         logger.debug(e)
         sys.exit()
     endpoint.fetch()
-    logger.info("Endpoint status: %s." %
-    if endpoint.status ==
+    logger.info("Endpoint status: %s." % endpoint.status)
+    if endpoint.status == 'scaledToZero':
         # Send a request to wake it up.
         get_generation(endpoint.url, "Wake up")
         sleep(TIMEOUT)
+    elif endpoint.status == 'failed':
+        logger.info("Endpoint failed, attempting to change compute.")
+        endpoint = update_endpoint_exception(endpoint)
+    wait_for_endpoint(endpoint)
+    if endpoint.status == 'failed':
+        logger.info("Endpoint failed, attempting to change compute.")
+        endpoint = update_endpoint_exception(endpoint)
+        wait_for_endpoint(endpoint)
+    logger.info("Endpoint created:")
+    logger.info(endpoint)
+    generation_url = endpoint.url
+    if generation_url is None:
+        logger.debug("Failed to create an endpoint. Exiting.")
+        sys.exit()
+    return generation_url
+
+
+def wait_for_endpoint(endpoint):
     i = 0
-    while endpoint.status in [
+    while endpoint.status in ['pending',
+                              'initializing']:  # not in ['failed', 'running', 'scaledToZero']
         if i >= 20:
             logger.info("Model failed to respond. Exiting.")
             sys.exit()
-        logger.debug(
+        logger.debug(
+            "Waiting %d seconds to check again if the endpoint is running." % TIMEOUT)
         sleep(TIMEOUT)
         endpoint.fetch()
         logger.debug("Endpoint status: %s." % (endpoint.status))
         i += 1
-    logger.info("Endpoint created:")
-    logger.info(endpoint)
-    generation_url = endpoint.url
-    return generation_url
 
 
-def update_endpoint_exception(
-    endpoint =
-    logger.debug("Attempting a new instance type.")
-    if instance_type == "nvidia-l4":
-        # Try a larger, different, more expensive GPU.
-        endpoint = create_inference_endpoint(endpoint_name,
-                                             repository=repository,
-                                             framework=framework, task=task,
-                                             accelerator=accelerator,
-                                             vendor=vendor, region=region,
-                                             type=type,
-                                             instance_size="x1",
-                                             instance_type="nvidia-a100")
-    elif instance_type == "a100" and instance_size == "x1":
-        endpoint = create_inference_endpoint(endpoint_name,
-                                             repository=repository,
-                                             framework=framework, task=task,
-                                             accelerator=accelerator,
-                                             vendor=vendor, region=region,
-                                             type=type,
-                                             instance_size="x4",
-                                             instance_type="nvidia-a10g")
-    else:
-        logger.info("Getting expensive to try to run this model without human oversight. Exiting.")
-        sys.exit()
+def update_endpoint_exception(endpoint):
+    raw_info = endpoint.raw
+    cur_instance_size = raw_info['compute']['instanceSize']
+    cur_instance_type = raw_info['compute']['instanceType']
+    if (cur_instance_type, cur_instance_size) == ('nvidia-l4', 'x4'):
+        endpoint.update(instance_size='x1', instance_type='nvidia-a100')
+    elif (cur_instance_type, cur_instance_size) == ('a100', 'x1'):
+        endpoint.update(instance_size='x4', instance_type='nvidia-a10g')
+    else:
+        logger.info(
+            "Getting expensive to try to run this model without human oversight. Exiting.")
+        sys.exit()
     return endpoint
 
 
 if __name__ == '__main__':
-    generation_url = create_endpoint(
+    generation_url = create_endpoint('this-is-a-test', 'Qwen/Qwen2-7B')
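
For context, a minimal sketch of how the new entry point might be driven from the evaluation backend. The endpoint name and model repository are the illustrative values from the __main__ block above, and get_generation is assumed, per its use inside create_endpoint, to take the endpoint URL and a prompt:

from src.backend.inference_endpoint import create_endpoint
from src.backend.run_toxicity_eval import get_generation

# Create (or reuse and update) a protected text-generation endpoint on
# AWS, wait until it reports a usable status, and get back its URL.
url = create_endpoint('this-is-a-test', 'Qwen/Qwen2-7B')

# Query it through the same helper the toxicity evaluation uses.
print(get_generation(url, "Wake up"))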
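A possible simplification, if the pinned huggingface_hub version supports it: the hub client exposes InferenceEndpoint.wait(), which performs the same fetch-and-sleep polling as the hand-rolled wait_for_endpoint loop and raises once the deadline passes or the endpoint enters a failed state. A sketch under that assumption (the 400-second budget mirrors the 20 iterations of TIMEOUT above):

from huggingface_hub import get_inference_endpoint

endpoint = get_inference_endpoint('this-is-a-test')
# Poll the endpoint status until it is running; raises on timeout or failure.
endpoint.wait(timeout=400)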