Optimizing the use of limited AI training accelerators — Part 2
This post was created in collaboration with Max Rabin.
This is the second part of a series of posts on the topic of maximizing the utility of scarce AI resources. In the first post we noted the increasing limitations on the ability to scale up AI resources at will and, as a consequence, the growing trend among AI development teams to guarantee AI compute capacity by means such as building up an in-house AI server farm and/or reserving dedicated instances in the cloud. The scarcity of AI compute resources motivates the design of specialized scheduling solutions to minimize idle time and prioritize critical workloads. Please see our previous post, in which we proposed a detailed list of requirements for such solutions. The approach we took there was to leverage the existing priority-based scheduler that comes with Kubernetes and align our training development workflow to its use. In this post we explore the option of maintaining our existing framework for training AI models and enhancing it with our own custom implementation of a priority-based scheduler. Importantly, the need for this type of solution is often motivated not just by the scarcity of AI resources, but also by the desire to increase control over the orchestration and prioritization of training workloads so as to reduce development costs. For example, even in a scenario of abundant capacity, you may choose to limit your use to a fixed number of training instances so as to cap your training expenditure.
For the purposes of this post, we will assume that our training framework of choice is AWS’s managed service for AI model training, Amazon SageMaker. The solution we will propose will use additional AWS services such as Amazon DynamoDB and AWS Lambda. The choice to demonstrate our solution using AWS services should not be viewed as an endorsement. There are many cloud-based service offerings available and the best one for you will depend on the particular details of your project. Similar solutions to the one that we will describe can be designed on other cloud-based environments and/or using alternative cloud-based services.
The Traditional Method for Starting Up SageMaker Training Jobs
Traditionally, we would start up a SageMaker training job using the Amazon SageMaker Python SDK. In the code block below we use the SageMaker SDK (version 2.208) to run a PyTorch training workload on a single instance of type p5.48xlarge.
from sagemaker.pytorch import PyTorch
# define job
estimator = PyTorch(
    role='<sagemaker role>',
    entry_point='train.py',
    instance_type='ml.p5.48xlarge',
    instance_count=1,
    framework_version='2.0.1',
    py_version='py310',
    tags=[{'Key': 'priority', 'Value': '100'}]
)
# start job
estimator.fit()
When the estimator.fit() function is called, the SageMaker library uploads our code to Amazon S3 and then transforms the request into a boto3 SageMaker client create_training_job request (see here).
This method for starting up training jobs is dependent on the availability of the requested resources for its success. In our scenario of scarce AI resources, it is likely to fail more often than not. Although this can be partially mitigated by retaining provisioned compute instances for successive workloads, the API does not provide the appropriate tooling for maximizing their utility. Let’s suppose that we wish to utilize precisely two p5.48xlarge instances. To simplify our discussion, let’s assume that each training workload runs on a single instance. Typically, during an AI model development cycle there will be periods when there are more than two training workloads that are waiting to be processed. The existing API would try to start up a third p5.48xlarge instance and would most likely fail due to its limited availability. Even when there is instance availability, we may wish to limit our training to just our two designated instances to increase our control over the costs of training.
We require a new API for submitting jobs for training, one that does not immediately start up a new p5.48xlarge instance, but rather enters the jobs into a priority queue. We also need an associated job scheduler that manages the use of our two resources while prioritizing critical workloads.
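For reference, here is a plain-Python sketch of the behavior we want, with no AWS services involved. Submitted jobs wait in a priority queue, and a job is dispatched only when one of a fixed number of instances is free. `ToyJobQueue` and its methods are hypothetical names used only for illustration. Higher priority values are served first, and equal priorities are served in submission order. Unlike the scheduler developed in this post, this toy never preempts a running job.

```python
import heapq
import itertools


class ToyJobQueue:
    """Hypothetical in-memory priority queue feeding a fixed pool of instances."""

    def __init__(self, capacity):
        self.capacity = capacity        # number of instances we allow ourselves
        self.running = set()            # names of jobs currently holding an instance
        self._heap = []                 # entries: (-priority, sequence, name)
        self._seq = itertools.count()   # monotonically increasing tie-breaker

    def submit(self, name, priority):
        # heapq is a min-heap, so negate the priority to pop the highest first;
        # the sequence number keeps equal-priority jobs in FIFO order.
        heapq.heappush(self._heap, (-priority, next(self._seq), name))
        self._dispatch()

    def finish(self, name):
        # Release the instance and let the next queued job take it.
        self.running.discard(name)
        self._dispatch()

    def pending(self):
        # Queued jobs in the order in which they would be dispatched.
        return [name for _, _, name in sorted(self._heap)]

    def _dispatch(self):
        while self._heap and len(self.running) < self.capacity:
            _, _, name = heapq.heappop(self._heap)
            self.running.add(name)


q = ToyJobQueue(capacity=2)
for name, prio in [('job1', 1), ('job2', 2), ('job3', 1), ('job4', 3)]:
    q.submit(name, prio)
print(sorted(q.running), q.pending())  # ['job1', 'job2'] ['job4', 'job3']
```

With two instances, job1 and job2 take the instances immediately. job4 and job3 wait, in that order, until a running job calls `finish`.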
Importantly, please note that as of the time of this writing, Amazon SageMaker does not support the option of training on reserved Amazon EC2 instances. And although Amazon SageMaker Savings Plans has similar properties to instance reservations, it does not guarantee instance capacity. In a previous post we addressed this limitation and proposed using SageMaker managed warm pools as an alternative method for retaining access to provisioned instances. For the remainder of the post, we will assume that we are able to obtain two instances of our choice, whether through this or some other method.
Priority-Based Scheduling for Amazon SageMaker
In this section we will describe the components of our proposed solution. We will use the AWS Serverless Application Model (SAM) specification. More specifically, we will create an AWS SAM template YAML file and gradually add the AWS resources that we need. Please see the documentation for details on how to define and deploy serverless solutions using AWS SAM.
A Private API for Submitting Training Jobs
We start by using Amazon API Gateway to define a private REST API for submitting training job requests. We name the API training-job-queue. Later, we will add a POST method called add-job and modify our training-job creation code to use this method instead of the SageMaker client create_training_job API. The code block below contains the definition of the private API resource in SAM. In practice you will likely want to specify access limitations to the API and/or a method of authorization.
AWSTemplateFormatVersion: '2010-09-09'
Transform: AWS::Serverless-2016-10-31
Resources:
  InternalAPI:
    Type: AWS::Serverless::Api
    Properties:
      # Auth: # Add access control to API
      EndpointConfiguration:
        Type: PRIVATE
        # VPCEndpointIds: # Specify VPC Endpoint(s)
      Name: training-job-queue
      StageName: prod
Define an AWS DynamoDB Table for Storing Training Job Requests
We will use an Amazon DynamoDB table named sagemaker-queue to store the submitted training workloads. Each entry will have the following fields:
- jobName: Stores the unique name of the training job.
- entryTime: Stores the date and time that the job was added.
- jobState: Stores the current state of the training job. The valid values are ‘pending’, ‘running’, and ‘preempted’.
- priority: Stores an integer value representing the relative priority of the job.
- jobDetails: Stores the details of the job request.
We define our DynamoDB table in our SAM template YAML file using the AWS::Serverless::SimpleTable resource.
  DynamoSMQueue:
    Type: AWS::Serverless::SimpleTable
    Properties:
      PrimaryKey:
        Name: jobName
        Type: String
      TableName: sagemaker-queue
We define a function that creates a table entry from a given training job request. We assume that the request contains the same contents as the input to the create_training_job API, in JSON format. We further assume that the priority of the workload is entered as a key-value tag in the training job definition.
import json
from datetime import datetime
import boto3
dynamodb = boto3.resource('dynamodb')
table = dynamodb.Table('sagemaker-queue')
def add_job_entry(job_json):
    # job_json is the create_training_job request as a JSON string
    job_details = json.loads(job_json)
    # extract job_name
    job_name = job_details['TrainingJobName']
    print(f'add entry {job_name}')
    # get current time
    entry_time = datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
    # default priority is 0
    priority = 0
    # update priority based on tags (the request may have no tags at all)
    tags = job_details.get('Tags', [])
    for tag in tags:
        if tag['Key'] == 'priority':
            priority = int(tag['Value'])
            break
    # create entry
    entry = {
        'jobName': job_name,
        'entryTime': entry_time,
        'jobState': 'pending',
        'priority': priority,
        'jobDetails': job_json
    }
    table.put_item(Item=entry)  # TODO handle errors
    print(f'Added job {job_name} to queue')
The REST API add-job method that we will soon define will be programmed to call the add_job_entry function.
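add_job_entry both parses the request and writes to DynamoDB. Moving the entry-building logic into a side-effect-free function makes it possible to check that logic locally. The sketch below does this. `build_queue_entry` and its optional `now` argument are hypothetical additions. The request is a trimmed, illustrative stand-in for a full create_training_job payload.

```python
import json
from datetime import datetime


def build_queue_entry(job_json, now=None):
    """Hypothetical pure helper mirroring the entry that add_job_entry writes."""
    job_details = json.loads(job_json)
    now = now or datetime.now()
    # Default priority is 0; a 'priority' tag overrides it.
    priority = 0
    for tag in job_details.get('Tags', []):
        if tag['Key'] == 'priority':
            priority = int(tag['Value'])
            break
    return {
        'jobName': job_details['TrainingJobName'],
        'entryTime': now.strftime('%Y-%m-%dT%H:%M:%S'),
        'jobState': 'pending',
        'priority': priority,
        'jobDetails': job_json,  # raw request, replayed later by create_training_job
    }


# A trimmed, illustrative request; a real one carries many more fields.
request = json.dumps({
    'TrainingJobName': 'job4',
    'Tags': [{'Key': 'priority', 'Value': '3'}],
})
entry = build_queue_entry(request, now=datetime(2024, 3, 1, 12, 30, 5))
print(entry['priority'], entry['entryTime'])  # 3 2024-03-01T12:30:05
```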
We define a second function that extracts the pending jobs from the database and returns them in order of priority. In the case that multiple jobs have the same priority, they are ordered according to the amount of time they have been waiting in the queue.
from boto3.dynamodb.conditions import Attr
# Get a list of all pending jobs sorted by priority
def get_pending_jobs():
    response = table.scan(
        ProjectionExpression='jobName, priority, entryTime',
        FilterExpression=Attr('jobState').ne('running')
    )
    jobs = response.get('Items', [])
    # sort jobs, first by priority (descending) and then by entryTime
    sorted_jobs = sorted(jobs,
                         key=lambda x: (-x['priority'], x['entryTime']))
    return sorted_jobs
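The sort key in get_pending_jobs works without any conversion, for two reasons. First, boto3 returns DynamoDB numbers as `decimal.Decimal`, which supports negation. Second, the zero-padded `%Y-%m-%dT%H:%M:%S` timestamps written by add_job_entry compare lexicographically in chronological order. The items below use made-up names and times to show the resulting order.

```python
from decimal import Decimal

# Illustrative scan results: numbers come back from DynamoDB as Decimal and
# entryTime is the zero-padded ISO-8601 string written by add_job_entry.
jobs = [
    {'jobName': 'job3', 'priority': Decimal(1), 'entryTime': '2024-03-01T12:00:40'},
    {'jobName': 'job1', 'priority': Decimal(1), 'entryTime': '2024-03-01T12:00:00'},
    {'jobName': 'job4', 'priority': Decimal(3), 'entryTime': '2024-03-01T12:01:00'},
]

# Same key as get_pending_jobs: priority descending, then oldest entry first.
sorted_jobs = sorted(jobs, key=lambda x: (-x['priority'], x['entryTime']))
print([job['jobName'] for job in sorted_jobs])  # ['job4', 'job1', 'job3']
```

Two jobs entered within the same second tie on both keys. In that case, Python's stable sort keeps the order in which the scan returned them.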
The following utility functions will come in handy in the next sections.
# Get a jobName -> priority mapping of all running jobs
def get_running_jobs_dict():
    # Get all running jobs
    response = table.scan(
        ProjectionExpression="jobName, priority",
        FilterExpression=Attr('jobState').eq('running')
    )
    jobs = response.get('Items', [])
    running_jobs = {job['jobName']: job['priority'] for job in jobs}
    return running_jobs
# Print the queue state
def print_queue_state():
    response = table.scan(
        ProjectionExpression='jobName, jobState, priority'
    )
    jobs = response.get('Items', [])
    print_table = []
    for job in jobs:
        print_table.append([job['jobName'], job['jobState'], job['priority']])
    # sort by priority
    sorted_table = sorted(print_table,
                          key=lambda x: -x[2])
    # Print the table
    from tabulate import tabulate
    print(tabulate(sorted_table, headers=['Job Name', 'State', 'Priority']))
# get job details or None if the job does not exist
def get_job_details(job_name):
    response = table.get_item(
        Key={'jobName': job_name},
        ProjectionExpression='jobDetails'
    )
    item = response.get('Item')
    return json.loads(item['jobDetails']) if item else None
# get job state or None if the job does not exist
def get_job_state(job_name):
    response = table.get_item(
        Key={'jobName': job_name},
        ProjectionExpression='jobState'
    )
    job = response.get('Item')
    return job.get('jobState') if job else None
# update the job state
def update_job_state(job_name, new_state):
    table.update_item(
        Key={'jobName': job_name},
        UpdateExpression="SET jobState = :new_state",
        ExpressionAttributeValues={":new_state": new_state}
    )
    print(f'Update job {job_name} to {new_state}')
# remove a job entry
def remove_job(job_name):
    table.delete_item(
        Key={'jobName': job_name}
    )
    print(f'Removed job {job_name} from queue')
Both our choice of DynamoDB and its usage (e.g., our use of the Scan API rather than the Query API) assume that the overall number of jobs in our queue will be in the dozens, at most. For a larger scale solution, you may be better off with a heavier duty database (e.g., one that performs the sorting operation for you) or a more sophisticated use of DynamoDB (e.g., see here).
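Each of the Scan calls above reads only a single response page, and DynamoDB caps a page at 1 MB of data. That is ample for a few dozen jobs. If the queue grows, a small helper can follow `LastEvaluatedKey` by passing it back as `ExclusiveStartKey` until no key is returned. The following is a minimal sketch. `scan_all` and `FakeTable` are hypothetical names. `FakeTable` only stands in for a boto3 `Table` so the loop can run without AWS, and it uses an offset as its page key, whereas a real `LastEvaluatedKey` holds the primary key of the last item read.

```python
def scan_all(table, **scan_kwargs):
    """Collect every item from a DynamoDB Table.scan, following pagination."""
    items = []
    while True:
        response = table.scan(**scan_kwargs)
        items.extend(response.get('Items', []))
        last_key = response.get('LastEvaluatedKey')
        if last_key is None:
            return items
        # Resume the scan where the previous page stopped.
        scan_kwargs['ExclusiveStartKey'] = last_key


class FakeTable:
    """Hypothetical stand-in for a boto3 Table that serves items in small pages."""

    def __init__(self, items, page_size=2):
        self.items = items
        self.page_size = page_size

    def scan(self, ExclusiveStartKey=None, **_ignored):
        start = 0 if ExclusiveStartKey is None else ExclusiveStartKey['offset']
        end = start + self.page_size
        response = {'Items': self.items[start:end]}
        if end < len(self.items):
            response['LastEvaluatedKey'] = {'offset': end}
        return response


queue_items = [{'jobName': f'job{i}'} for i in range(5)]
print(len(scan_all(FakeTable(queue_items))))  # 5
```

With a real table, the same call would look like `scan_all(table, FilterExpression=Attr('jobState').ne('running'))`.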
Define the Training Job Queue Manager
The main component of our solution is the training job scheduler. Here we implement a rather simple manager that performs the following steps:
- Extract the list of queued jobs, ordered by priority. If none exist, return.
- Discover unused instance capacity. For each free instance, start one pending job on SageMaker. If no jobs remain after that, return.
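- Calculate the number of SageMaker jobs in the Stopping state. If it is greater than or equal to the number of jobs that are still pending, return, since the instances being released will be taken up by the pending jobs.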
- Assess the need for preemption of running SageMaker jobs by comparing their priorities to those of our pending jobs.
# set the limit on total number of instances/jobs
MAX_CAPACITY = 2
sagemaker = boto3.client('sagemaker')
# apply a queue stamp to identify that the job came from the queue
def apply_qstamp(job_name):
    return f'{job_name}-qstamp-{datetime.now().strftime("%d%H%M")}'
# strip the queue stamp
def strip_qstamp(job_name):
    return job_name.split('-qstamp-')[0]
# start a SageMaker job and update job entry in queue
def start_job(job_name):
    print(f'start job {job_name}')
    job_details = get_job_details(job_name)
    if job_details:
        # start job with details from queue, under a queue-stamped name
        # (you may optionally overwrite fields such as the iam role)
        job_details['TrainingJobName'] = apply_qstamp(job_name)
        response = sagemaker.create_training_job(**job_details)
        if response['ResponseMetadata']['HTTPStatusCode'] == 200:
            print(f'started job {job_name}')
            update_job_state(job_name, 'running')
# preempt a SageMaker job and update job entry in queue
def preempt_job(job_name):
    print(f'preempt job {job_name}')
    response = sagemaker.stop_training_job(TrainingJobName=job_name)
    if response['ResponseMetadata']['HTTPStatusCode'] == 200:
        print(f'preempted job {job_name}')
        update_job_state(strip_qstamp(job_name), 'preempted')
# get SageMaker jobs
def get_sagemaker_jobs(status):
    response = sagemaker.list_training_jobs(StatusEquals=status)
    return response.get('TrainingJobSummaries', [])
# queue manager
def manage_queue():
    # extract pending jobs to run
    pending = get_pending_jobs()
    if not pending:
        return
    if len(pending) > MAX_CAPACITY:
        pending = pending[:MAX_CAPACITY]
    # get running sagemaker jobs
    running = get_sagemaker_jobs('InProgress')
    total_running = len(running)
    # get stopping sagemaker jobs
    stopping = get_sagemaker_jobs('Stopping')
    total_stopping = len(stopping)
    # calculate the number of free instances (never negative)
    free_slots = MAX_CAPACITY - total_running - total_stopping
    jobs_to_start = max(0, min(len(pending), free_slots))
    # for each free instance, start a job
    for i in range(jobs_to_start):
        start_job(pending[i].get('jobName'))
    still_pending = pending[jobs_to_start:]
    if not still_pending:
        return
    # assume that 'total_stopping' number of jobs will start soon
    test_for_preemption = len(still_pending) - total_stopping
    if test_for_preemption <= 0:
        return
    # check if preemption is required
    test_priority = still_pending[total_stopping:]
    running_jobs = get_running_jobs_dict()
    priority_dict = {}
    for job in running:
        job_name = job['TrainingJobName']
        base_name = strip_qstamp(job_name)
        # skip SageMaker jobs that were not started by the queue
        if base_name in running_jobs:
            priority_dict[job_name] = running_jobs[base_name]
    # sort running jobs from lowest to highest priority
    sorted_running = sorted(priority_dict.items(), key=lambda item: item[1])
    index = 0
    while (index < test_for_preemption and index < len(sorted_running) and
           test_priority[index].get('priority') > sorted_running[index][1]):
        preempt_job(sorted_running[index][0])
        index = index + 1
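The scheduling decisions can be checked without touching AWS by restating the core of manage_queue as a pure function over plain data. This is a hypothetical refactoring for illustration, not part of the deployed Lambda. `plan_queue_actions` takes the sorted pending jobs, a mapping of running jobs to priorities, and the number of stopping jobs. It returns which jobs to start and which to preempt, following the same steps as manage_queue.

```python
def plan_queue_actions(pending, running, total_stopping, max_capacity=2):
    """Hypothetical pure restatement of manage_queue's decision logic.

    pending: list of (job_name, priority), sorted by priority (descending)
             and then by time in queue.
    running: dict mapping each running job_name to its priority.
    total_stopping: number of jobs that are still releasing an instance.
    Returns (jobs_to_start, jobs_to_preempt).
    """
    pending = pending[:max_capacity]
    free_slots = max(0, max_capacity - len(running) - total_stopping)
    n_start = min(len(pending), free_slots)
    to_start = [name for name, _ in pending[:n_start]]
    still_pending = pending[n_start:]
    # Instances that are stopping will soon be taken by the top waiting jobs,
    # so only the remaining ones are candidates for preemption.
    candidates = still_pending[total_stopping:]
    # Pair the strongest candidates with the weakest running jobs.
    weakest_first = sorted(running.items(), key=lambda item: item[1])
    to_preempt = []
    for (_, cand_prio), (run_name, run_prio) in zip(candidates, weakest_first):
        if cand_prio <= run_prio:
            break
        to_preempt.append(run_name)
    return to_start, to_preempt


# job4 (priority 3) arrives while job1 (1) and job2 (2) hold both instances.
print(plan_queue_actions([('job4', 3), ('job3', 1)], {'job1': 1, 'job2': 2}, 0))
# ([], ['job1'])
```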
Important notes:
- Our implementation is highly optimistic in the sense that we assume that all the jobs that are inserted are valid and that we will be able to start them up on SageMaker without issue. In practice, appropriate error handling should be added (e.g., removing faulty jobs from the queue with appropriate logging).
- In a production environment, we would need to take into consideration the likely occurrence of a race condition when our manage_queue function is triggered by multiple concurrent events. There are several ways of addressing this problem (e.g., see here), including enforcing atomicity (e.g., by setting our Lambda function concurrency to one), using some form of locking mechanism (e.g., as done here), or making our function idempotent. Here we have taken the approach of what we call “optimistic idempotence”, where we rely on appropriate use of the API and on the idempotency of our underlying calls to the SageMaker APIs.
- We emphasize that our implementation is naïve. In practice, we recommend a more sophisticated algorithm that 1) accounts for the use of different types of instances and jobs that require more than one instance, 2) takes all edge cases into consideration, and 3) is tailored towards the specific needs of your project.
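The following is one reading of the “optimistic idempotence” argument, offered as an interpretation rather than something the implementation above guarantees. apply_qstamp appends only the day, hour, and minute (`%d%H%M`). Suppose two concurrent invocations of manage_queue try to start the same queued job within the same minute. Both produce an identical TrainingJobName. SageMaker requires training job names to be unique within an account and Region, so the second create_training_job call is rejected rather than launching a duplicate. The sketch adds a hypothetical optional `now` argument so that this property can be shown deterministically. Calls that straddle a minute boundary would still receive distinct names.

```python
from datetime import datetime


def apply_qstamp(job_name, now=None):
    # Same stamp format as the queue manager, with an injectable clock (hypothetical).
    now = now or datetime.now()
    return f'{job_name}-qstamp-{now.strftime("%d%H%M")}'


def strip_qstamp(job_name):
    return job_name.split('-qstamp-')[0]


t1 = datetime(2024, 3, 1, 12, 30, 5)
t2 = datetime(2024, 3, 1, 12, 30, 55)  # same minute, 50 seconds later
t3 = datetime(2024, 3, 1, 12, 31, 1)   # the following minute
print(apply_qstamp('job1', t1))                              # job1-qstamp-011230
print(apply_qstamp('job1', t1) == apply_qstamp('job1', t2))  # True
print(apply_qstamp('job1', t1) == apply_qstamp('job1', t3))  # False
```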
Define the AWS Lambda Function
The next component of the solution is the Lambda function. The following code block includes the SAM definition of our serverless function. We program the function to run on two different types of events: any call to add-job on our private API gateway, and a transition of a SageMaker training job to one of the terminal states Completed, Failed, or Stopped.
  ManagedTrainingJobQueue:
    Type: AWS::Serverless::Function
    Properties:
      CodeUri: job-queue/ # the directory containing our index.py file
      Handler: index.lambda_handler
      Runtime: python3.12
      Architectures:
        - arm64 # use graviton
      Policies: # allow access to SageMaker and DynamoDB
        - !Sub "arn:${AWS::Partition}:iam::aws:policy/AmazonSageMakerFullAccess"
        - DynamoDBCrudPolicy:
            TableName: !Ref DynamoSMQueue
      Events:
        CreateTraining:
          Type: Api
          Properties:
            Path: /add-job
            Method: post
            RestApiId: !Ref InternalAPI
        SageMakerEvent:
          Type: EventBridgeRule
          Properties:
            Pattern:
              source:
                - aws.sagemaker
              detail-type:
                - SageMaker Training Job State Change
              detail:
                TrainingJobStatus:
                  - "Completed"
                  - "Failed"
                  - "Stopped"
The lambda_handler function is implemented as follows:
def lambda_handler(event, context):
    # identify source of event and take appropriate action
    if 'requestContext' in event and 'apiId' in event['requestContext']:
        print('Lambda triggered by API Gateway')
        # the body is the JSON-encoded create_training_job request;
        # add_job_entry parses it and stores the raw string as jobDetails
        add_job_entry(event.get('body'))
    elif 'source' in event and event['source'] == 'aws.sagemaker':
        print('Lambda triggered by SageMaker job state change')
        job_name = event['detail']['TrainingJobName']
        job_status = event['detail']['TrainingJobStatus']
        print(f'{job_name} status changed to {job_status}')
        # strip qstamp from job_name
        job_name = strip_qstamp(job_name)
        if job_status in ['Completed', 'Failed']:
            remove_job(job_name)
        elif job_status == 'Stopped':
            # check if it was manually stopped or preempted by queue manager
            if get_job_state(job_name) == 'preempted':
                print(f'job {job_name} preemption completed')
            else:
                print(f'job {job_name} {job_status}, remove from queue')
                remove_job(job_name)
    # in all cases invoke queue manager
    manage_queue()
    # API Gateway proxy integration expects a response with a status code
    return {'statusCode': 200, 'body': json.dumps({'status': 'ok'})}
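The handler distinguishes its two triggers purely by the shape of the incoming event. API Gateway's Lambda proxy integration delivers the POST body as a JSON string in `event['body']`, alongside a `requestContext` containing the `apiId`. EventBridge delivers the SageMaker state change under `event['detail']`. The sketch below isolates that routing so it can be exercised with hand-written events. `classify_event` is a hypothetical helper, and the sample events are trimmed far below what AWS actually sends.

```python
import json


def strip_qstamp(job_name):
    return job_name.split('-qstamp-')[0]


def classify_event(event):
    """Hypothetical side-effect-free mirror of lambda_handler's routing."""
    if 'requestContext' in event and 'apiId' in event['requestContext']:
        job = json.loads(event['body'])  # proxy integration: body is a string
        return ('add', job['TrainingJobName'])
    if event.get('source') == 'aws.sagemaker':
        detail = event['detail']
        return (detail['TrainingJobStatus'], strip_qstamp(detail['TrainingJobName']))
    return ('ignored', None)


# Trimmed-down illustrative events; real events carry many more fields.
api_event = {
    'requestContext': {'apiId': 'abc123'},
    'body': json.dumps({'TrainingJobName': 'job4'}),
}
sagemaker_event = {
    'source': 'aws.sagemaker',
    'detail-type': 'SageMaker Training Job State Change',
    'detail': {'TrainingJobName': 'job2-qstamp-011230',
               'TrainingJobStatus': 'Completed'},
}
print(classify_event(api_event))        # ('add', 'job4')
print(classify_event(sagemaker_event))  # ('Completed', 'job2')
```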
Intercept the Create Training Job Request
The final modification required to make our solution complete is to intercept the call to the SageMaker create_training_job API and reroute it to our add-job method. We do this by overriding the _intercept_create_request function of the SageMaker Session class:
import json, logging
import requests
from sagemaker.pytorch import PyTorch
from sagemaker.session import Session
logger = logging.getLogger('sagemaker')
def submit_to_training_queue(job):
    logger.info(f"Adding training-job {job['TrainingJobName']} to queue")
    logger.debug(f'train request: {json.dumps(job, indent=4)}')
    vpce = '<vpc endpoint>'  # insert id of vpc endpoint
    region = 'us-east-1'  # specify region
    url = f'https://{vpce}.execute-api.{region}.vpce.amazonaws.com/prod/add-job'
    headers = {'x-apigw-api-id': '<api-id>'}  # insert api gateway id
    # submit job
    response = requests.post(url, headers=headers, json=job)
    # surface HTTP errors instead of silently dropping the submission
    response.raise_for_status()
class QueueTrainingJobSession(Session):
    def _intercept_create_request(self, request, create, func_name=None):
        """This function intercepts the create job request
        Args:
            request (dict): the create job request
            create (functor): a functor that calls the sagemaker client create method
            func_name (str): the name of the calling function to intercept
        """
        if func_name == 'train':
            submit_to_training_queue(request)
        else:
            return super()._intercept_create_request(request, create, func_name)
# define job
estimator = PyTorch(
    role='<sagemaker role>',
    entry_point='train.py',
    instance_type='ml.p5.48xlarge',
    instance_count=1,
    framework_version='2.0.1',
    py_version='py310',
    tags=[{'Key': 'priority', 'Value': '100'}],
    keep_alive_period_in_seconds=60,  # keep warm for 1 minute
    # use our custom Session class
    sagemaker_session=QueueTrainingJobSession()
)
estimator.fit(wait=False)
Use Case Example
To test our solution we submit the following sequence of jobs. After each call we print the status of the queue (using the print_queue_state function) and sleep for twenty seconds.
- Start job1 with priority 1.
- Start job2 with priority 2.
- Start job3 with priority 1.
- Start job4 with priority 3.
The first two jobs are immediately submitted to SageMaker and updated to the running state. Since the third job has low priority and we have precisely two training instances, it remains in the pending state and waits its turn. After submitting the first three jobs, the queue state appears as:
Job Name State Priority
---------- ------- ----------
job2 running 2
job1 running 1
job3 pending 1
The fourth job we submit has a higher priority than all of the jobs in the queue. Consequently, the running job with the lowest priority, job1, is preempted. The corresponding SageMaker job is stopped and once the instance is released, the queue state becomes:
Job Name State Priority
---------- --------- ----------
job4 running 3
job2 running 2
job1 preempted 1
job3 pending 1
The SageMaker job running job2 is the first to finish, job2 is removed from the queue, and our preempted job is resumed:
Job Name State Priority
---------- ------- ----------
job4 running 3
job1 running 1
job3 pending 1
Once job4 is completed, it too is removed from the queue, making room for job3. The remaining jobs are also run to completion, ultimately leaving our queue empty.
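The sequence above can be replayed with a small, self-contained simulation. This is a hypothetical toy rather than the implementation described in this post. Stops and starts are instantaneous, so a preemption frees its instance immediately instead of waiting for the SageMaker Stopped event. The ordering rules are the ones used throughout: priority first, then time in queue, with preempted jobs re-entering the queue.

```python
class ToyScheduler:
    """Hypothetical instantaneous model of the queue (no AWS, capacity >= 1)."""

    def __init__(self, capacity=2):
        self.capacity = capacity
        self.jobs = {}   # name -> {'priority', 'state', 'seq'}
        self._seq = 0    # submission order, standing in for entryTime

    def submit(self, name, priority):
        self.jobs[name] = {'priority': priority, 'state': 'pending', 'seq': self._seq}
        self._seq += 1
        self._manage()

    def finish(self, name):
        # A completed job is removed from the queue, freeing its instance.
        del self.jobs[name]
        self._manage()

    def state(self):
        return {name: job['state'] for name, job in self.jobs.items()}

    def _with_state(self, *states):
        return [name for name, job in self.jobs.items() if job['state'] in states]

    def _manage(self):
        # Waiting jobs (pending or preempted): highest priority first, then oldest.
        waiting = sorted(self._with_state('pending', 'preempted'),
                         key=lambda n: (-self.jobs[n]['priority'], self.jobs[n]['seq']))
        for name in waiting:
            running = self._with_state('running')
            if len(running) < self.capacity:
                self.jobs[name]['state'] = 'running'
                continue
            # No free instance: preempt the weakest running job if it ranks lower.
            weakest = min(running, key=lambda n: self.jobs[n]['priority'])
            if self.jobs[weakest]['priority'] < self.jobs[name]['priority']:
                self.jobs[weakest]['state'] = 'preempted'
                self.jobs[name]['state'] = 'running'


sched = ToyScheduler(capacity=2)
for job_name, prio in [('job1', 1), ('job2', 2), ('job3', 1), ('job4', 3)]:
    sched.submit(job_name, prio)
print(sched.state())
# {'job1': 'preempted', 'job2': 'running', 'job3': 'pending', 'job4': 'running'}
sched.finish('job2')
print(sched.state())
# {'job1': 'running', 'job3': 'pending', 'job4': 'running'}
```

job1 and job3 share a priority, so job1, which entered the queue first, resumes ahead of job3 once job2 finishes, matching the queue states shown above.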
Summary
The increasing difficulty of acquiring AI compute capacity has forced AI development teams to reevaluate the processes they use for training AI models. The approach we have demonstrated in this post is to augment the traditional APIs for training models with a custom-made priority queue and an associated job scheduler. Importantly, the proposal we have put forth should be viewed as a general scheme, not as a production-worthy solution. Appropriate modifications and enhancements would be required to address the specific needs of your project.
A Priority Based Scheduler for Amazon SageMaker Training Jobs was originally published in Towards Data Science on Medium, where people are continuing the conversation by highlighting and responding to this story.