Prefect worker for executing flow runs as ECS tasks.
Get started by creating a work pool:
$ prefect work-pool create --type ecs my-ecs-pool
Then, you can start a worker for the pool:
$ prefect worker start --pool my-ecs-pool
It's common to deploy the worker as an ECS task as well. However, you can run the worker
locally to get started.
The worker may work without any additional configuration, but this depends on your
specific AWS setup. We recommend opening the work pool editor in the UI to see the
available options.
By default, the worker will register a task definition for each flow run and run a task
in your default ECS cluster using AWS Fargate. Fargate requires tasks to configure
subnets, which we will infer from your default VPC. If you do not have a default VPC,
you must provide a VPC ID or manually set up the network configuration for your tasks.
Note, the worker caches task definitions for each deployment to avoid excessive
registration. The worker will check that the cached task definition is compatible with
your configuration before using it.
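The caching behavior described above can be sketched roughly as follows. This is a simplified illustration under stated assumptions, not the worker's implementation: `get_or_register`, the `registry` dict standing in for ECS, and `fake_register` are hypothetical names, and compatibility is reduced to a plain dict comparison.

```python
from typing import Callable, Dict, Tuple

# Hypothetical in-memory cache: deployment ID -> task definition ARN
_cache: Dict[str, str] = {}

# Hypothetical stand-in for the task definitions stored in ECS
registry: Dict[str, dict] = {}


def fake_register(definition: dict) -> str:
    """Pretend to register a task definition and return a made-up ARN."""
    arn = f"arn:example:task-definition/demo:{len(registry) + 1}"
    registry[arn] = {"status": "ACTIVE", "definition": definition}
    return arn


def get_or_register(
    deployment_id: str,
    desired: dict,
    register: Callable[[dict], str] = fake_register,
) -> Tuple[str, bool]:
    """Return (arn, newly_registered) for the desired task definition."""
    cached_arn = _cache.get(deployment_id)
    if cached_arn is not None:
        cached = registry.get(cached_arn)
        # Reuse only if the cached definition is still active and still matches
        if (
            cached
            and cached["status"] == "ACTIVE"
            and cached["definition"] == desired
        ):
            return cached_arn, False
    arn = register(desired)
    _cache[deployment_id] = arn
    return arn, True
```

Calling `get_or_register` twice with the same deployment ID and definition reuses the first ARN; changing the definition forces a new registration.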
The launch type option selects how your tasks are run. For example, FARGATE_SPOT runs
your Fargate tasks on spot capacity, and EC2 runs your tasks on a cluster backed by EC2
instances.
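As a rough illustration of how a launch type could translate into ECS run request fields, consider the sketch below. `apply_launch_type` is a hypothetical helper, and the weight of 1 on the capacity provider is an assumption; the field documentation only states that FARGATE_SPOT is handled through a capacity provider strategy and that an explicit strategy causes the launch type to be ignored.

```python
from typing import Any, Dict


def apply_launch_type(request: Dict[str, Any], launch_type: str) -> Dict[str, Any]:
    """Return a copy of the run request with launch settings applied."""
    request = dict(request)
    if request.get("capacityProviderStrategy"):
        # An explicit strategy wins; the launch type is ignored
        request.pop("launchType", None)
    elif launch_type == "FARGATE_SPOT":
        # FARGATE_SPOT is not a formal launch type, so express it as a
        # capacity provider strategy instead (weight is an assumption)
        request.pop("launchType", None)
        request["capacityProviderStrategy"] = [
            {"capacityProvider": "FARGATE_SPOT", "weight": 1}
        ]
    else:
        request["launchType"] = launch_type
    return request
```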
Generally, it is very useful to enable CloudWatch logging for your ECS tasks; this can
help you debug task failures. To enable CloudWatch logging, you must provide an
execution role ARN with permissions to create and write to log streams. See the
configure_cloudwatch_logs field documentation for details.
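For orientation, the permissions named in the configure_cloudwatch_logs field description could be expressed as an IAM policy document along these lines. This is a sketch only: `cloudwatch_logs_policy` is a hypothetical helper, and the wildcard resource is an assumption that you would likely narrow to specific log groups.

```python
import json

# Actions listed in the configure_cloudwatch_logs field description
LOG_ACTIONS = ["logs:CreateLogGroup", "logs:CreateLogStream", "logs:PutLogEvents"]


def cloudwatch_logs_policy(resource: str = "*") -> dict:
    """Build an IAM policy document granting the logging permissions."""
    return {
        "Version": "2012-10-17",
        "Statement": [
            {"Effect": "Allow", "Action": LOG_ACTIONS, "Resource": resource}
        ],
    }


# JSON form, suitable for attaching to the execution role by whatever means you use
policy_json = json.dumps(cloudwatch_logs_policy(), indent=2)
```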
The worker can be configured to use an existing task definition by setting the
task_definition_arn variable or by providing a "taskDefinition" in the task run request.
When a task definition is provided, the worker will never create a new task definition,
which may result in variables that are templated into the task definition payload being
ignored.
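The precedence described above can be summarized in a small sketch. `task_definition_source` is a hypothetical helper written for illustration; it mirrors the rule that an ARN on the task run request takes priority and that one of the two must be present.

```python
from typing import Any, Dict, Optional


def task_definition_source(
    task_run_request: Dict[str, Any],
    task_definition: Optional[Dict[str, Any]],
) -> str:
    """Return "arn" or "template" depending on which definition is used."""
    if task_run_request.get("taskDefinition"):
        # An ARN on the run request wins; the templated definition is ignored
        return "arn"
    if task_definition:
        return "template"
    raise ValueError(
        "A task definition must be provided if a task definition ARN is not "
        "present on the task run request."
    )
```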
class ECSJobConfiguration(BaseJobConfiguration):
    """
    Job configuration for an ECS worker.
    """

    aws_credentials: Optional[AwsCredentials] = Field(default_factory=AwsCredentials)
    task_definition: Optional[Dict[str, Any]] = Field(
        template=_default_task_definition_template()
    )
    task_run_request: Dict[str, Any] = Field(
        template=_default_task_run_request_template()
    )
    configure_cloudwatch_logs: Optional[bool] = Field(default=None)
    cloudwatch_logs_options: Dict[str, str] = Field(default_factory=dict)
    cloudwatch_logs_prefix: Optional[str] = Field(default=None)
    network_configuration: Dict[str, Any] = Field(default_factory=dict)
    stream_output: Optional[bool] = Field(default=None)
    task_start_timeout_seconds: int = Field(default=300)
    task_watch_poll_interval: float = Field(default=5.0)
    auto_deregister_task_definition: bool = Field(default=False)
    vpc_id: Optional[str] = Field(default=None)
    container_name: Optional[str] = Field(default=None)
    cluster: Optional[str] = Field(default=None)
    match_latest_revision_in_family: bool = Field(default=False)

    @root_validator
    def task_run_request_requires_arn_if_no_task_definition_given(
        cls, values
    ) -> dict:
        """
        If no task definition is provided, a task definition ARN must be present on the
        task run request.
        """
        if not values.get("task_run_request", {}).get(
            "taskDefinition"
        ) and not values.get("task_definition"):
            raise ValueError(
                "A task definition must be provided if a task definition ARN is not "
                "present on the task run request."
            )
        return values

    @root_validator
    def container_name_default_from_task_definition(cls, values) -> dict:
        """
        Infers the container name from the task definition if not provided.
        """
        if values.get("container_name") is None:
            values["container_name"] = _container_name_from_task_definition(
                values.get("task_definition")
            )

            # We may not have a name here still; for example if someone is using a task
            # definition arn. In that case, we'll perform similar logic later to find
            # the name to treat as the "orchestration" container.

        return values

    @root_validator(pre=True)
    def set_default_configure_cloudwatch_logs(cls, values: dict) -> dict:
        """
        Streaming output generally requires CloudWatch logs to be configured.

        To avoid entangled arguments in the simple case, `configure_cloudwatch_logs`
        defaults to matching the value of `stream_output`.
        """
        configure_cloudwatch_logs = values.get("configure_cloudwatch_logs")
        if configure_cloudwatch_logs is None:
            values["configure_cloudwatch_logs"] = values.get("stream_output")
        return values

    @root_validator
    def configure_cloudwatch_logs_requires_execution_role_arn(
        cls, values: dict
    ) -> dict:
        """
        Enforces that an execution role arn is provided (or could be provided by a
        runtime task definition) when configuring logging.
        """
        if (
            values.get("configure_cloudwatch_logs")
            and not values.get("execution_role_arn")
            # TODO: Does not match
            # Do not raise if they've linked to another task definition or provided
            # it without using our shortcuts
            and not values.get("task_run_request", {}).get("taskDefinition")
            and not (values.get("task_definition") or {}).get("executionRoleArn")
        ):
            raise ValueError(
                "An `execution_role_arn` must be provided to use "
                "`configure_cloudwatch_logs` or `stream_logs`."
            )
        return values

    @root_validator
    def cloudwatch_logs_options_requires_configure_cloudwatch_logs(
        cls, values: dict
    ) -> dict:
        """
        Enforces that an execution role arn is provided (or could be provided by a
        runtime task definition) when configuring logging.
        """
        if values.get("cloudwatch_logs_options") and not values.get(
            "configure_cloudwatch_logs"
        ):
            raise ValueError(
                "`configure_cloudwatch_log` must be enabled to use "
                "`cloudwatch_logs_options`."
            )
        return values

    @root_validator
    def network_configuration_requires_vpc_id(cls, values: dict) -> dict:
        """
        Enforces a `vpc_id` is provided when custom network configuration mode is
        enabled for network settings.
        """
        if values.get("network_configuration") and not values.get("vpc_id"):
            raise ValueError(
                "You must provide a `vpc_id` to enable custom `network_configuration`."
            )
        return values
Enforces that an execution role arn is provided (or could be provided by a
runtime task definition) when configuring logging.
Source code in prefect_aws/workers/ecs_worker.py
@root_validator
def cloudwatch_logs_options_requires_configure_cloudwatch_logs(
    cls, values: dict
) -> dict:
    """
    Enforces that an execution role arn is provided (or could be provided by a
    runtime task definition) when configuring logging.
    """
    if values.get("cloudwatch_logs_options") and not values.get(
        "configure_cloudwatch_logs"
    ):
        raise ValueError(
            "`configure_cloudwatch_log` must be enabled to use "
            "`cloudwatch_logs_options`."
        )
    return values

@root_validator
def configure_cloudwatch_logs_requires_execution_role_arn(
    cls, values: dict
) -> dict:
    """
    Enforces that an execution role arn is provided (or could be provided by a
    runtime task definition) when configuring logging.
    """
    if (
        values.get("configure_cloudwatch_logs")
        and not values.get("execution_role_arn")
        # TODO: Does not match
        # Do not raise if they've linked to another task definition or provided
        # it without using our shortcuts
        and not values.get("task_run_request", {}).get("taskDefinition")
        and not (values.get("task_definition") or {}).get("executionRoleArn")
    ):
        raise ValueError(
            "An `execution_role_arn` must be provided to use "
            "`configure_cloudwatch_logs` or `stream_logs`."
        )
    return values
Infers the container name from the task definition if not provided.
Source code in prefect_aws/workers/ecs_worker.py
@root_validator
def container_name_default_from_task_definition(cls, values) -> dict:
    """
    Infers the container name from the task definition if not provided.
    """
    if values.get("container_name") is None:
        values["container_name"] = _container_name_from_task_definition(
            values.get("task_definition")
        )

        # We may not have a name here still; for example if someone is using a task
        # definition arn. In that case, we'll perform similar logic later to find
        # the name to treat as the "orchestration" container.

    return values
Enforces a vpc_id is provided when custom network configuration mode is
enabled for network settings.
Source code in prefect_aws/workers/ecs_worker.py
@root_validator
def network_configuration_requires_vpc_id(cls, values: dict) -> dict:
    """
    Enforces a `vpc_id` is provided when custom network configuration mode is
    enabled for network settings.
    """
    if values.get("network_configuration") and not values.get("vpc_id"):
        raise ValueError(
            "You must provide a `vpc_id` to enable custom `network_configuration`."
        )
    return values
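To make the rule concrete, here is a minimal sketch of assembling a custom network configuration alongside a VPC ID. `build_network_settings` is a hypothetical helper and the IDs in the usage note are placeholders; the keys `subnets`, `securityGroups`, and `assignPublicIp` are those of the ECS `awsvpcConfiguration` structure.

```python
from typing import Any, Dict, List, Optional


def build_network_settings(
    vpc_id: Optional[str],
    subnets: List[str],
    security_groups: List[str],
    assign_public_ip: bool = False,
) -> Dict[str, Any]:
    """Return vpc_id and network_configuration values that pass the check."""
    if not vpc_id:
        # Same condition the validator enforces
        raise ValueError(
            "You must provide a `vpc_id` to enable custom `network_configuration`."
        )
    return {
        "vpc_id": vpc_id,
        "network_configuration": {
            "subnets": subnets,
            "securityGroups": security_groups,
            "assignPublicIp": "ENABLED" if assign_public_ip else "DISABLED",
        },
    }
```

For example, `build_network_settings("vpc-0123", ["subnet-aaaa"], ["sg-bbbb"])` returns both values together, while passing `None` as the VPC ID raises the same error as the validator.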
Streaming output generally requires CloudWatch logs to be configured.
To avoid entangled arguments in the simple case, configure_cloudwatch_logs
defaults to matching the value of stream_output.
Source code in prefect_aws/workers/ecs_worker.py
@root_validator(pre=True)
def set_default_configure_cloudwatch_logs(cls, values: dict) -> dict:
    """
    Streaming output generally requires CloudWatch logs to be configured.

    To avoid entangled arguments in the simple case, `configure_cloudwatch_logs`
    defaults to matching the value of `stream_output`.
    """
    configure_cloudwatch_logs = values.get("configure_cloudwatch_logs")
    if configure_cloudwatch_logs is None:
        values["configure_cloudwatch_logs"] = values.get("stream_output")
    return values
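The defaulting rule can be exercised in isolation with a standalone function. `resolve_configure_cloudwatch_logs` is a hypothetical name used only for this sketch; it applies the same logic outside of the pydantic model.

```python
from typing import Optional


def resolve_configure_cloudwatch_logs(
    configure_cloudwatch_logs: Optional[bool],
    stream_output: Optional[bool],
) -> Optional[bool]:
    """Fall back to stream_output when the logging flag was left unset."""
    if configure_cloudwatch_logs is None:
        return stream_output
    return configure_cloudwatch_logs
```

An explicit `False` is preserved even when `stream_output` is `True`; only an unset (`None`) value inherits from `stream_output`.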
If no task definition is provided, a task definition ARN must be present on the
task run request.
Source code in prefect_aws/workers/ecs_worker.py
@root_validator
def task_run_request_requires_arn_if_no_task_definition_given(
    cls, values
) -> dict:
    """
    If no task definition is provided, a task definition ARN must be present on the
    task run request.
    """
    if not values.get("task_run_request", {}).get(
        "taskDefinition"
    ) and not values.get("task_definition"):
        raise ValueError(
            "A task definition must be provided if a task definition ARN is not "
            "present on the task run request."
        )
    return values
class ECSVariables(BaseVariables):
    """
    Variables for templating an ECS job.
    """

    task_definition_arn: Optional[str] = Field(
        default=None,
        description=(
            "An identifier for an existing task definition to use. If set, options that"
            " require changes to the task definition will be ignored. All contents of "
            "the task definition in the job configuration will be ignored."
        ),
    )
    env: Dict[str, Optional[str]] = Field(
        title="Environment Variables",
        default_factory=dict,
        description=(
            "Environment variables to provide to the task run. These variables are set "
            "on the Prefect container at task runtime. These will not be set on the "
            "task definition."
        ),
    )
    aws_credentials: AwsCredentials = Field(
        title="AWS Credentials",
        default_factory=AwsCredentials,
        description=(
            "The AWS credentials to use to connect to ECS. If not provided, credentials"
            " will be inferred from the local environment following AWS's boto client's"
            " rules."
        ),
    )
    cluster: Optional[str] = Field(
        default=None,
        description=(
            "The ECS cluster to run the task in. An ARN or name may be provided. If "
            "not provided, the default cluster will be used."
        ),
    )
    family: Optional[str] = Field(
        default=None,
        description=(
            "A family for the task definition. If not provided, it will be inferred "
            "from the task definition. If the task definition does not have a family, "
            "the name will be generated. When flow and deployment metadata is "
            "available, the generated name will include their names. Values for this "
            "field will be slugified to match AWS character requirements."
        ),
    )
    launch_type: Optional[
        Literal["FARGATE", "EC2", "EXTERNAL", "FARGATE_SPOT"]
    ] = Field(
        default=ECS_DEFAULT_LAUNCH_TYPE,
        description=(
            "The type of ECS task run infrastructure that should be used. Note that"
            " 'FARGATE_SPOT' is not a formal ECS launch type, but we will configure"
            " the proper capacity provider strategy if set here."
        ),
    )
    capacity_provider_strategy: Optional[List[CapacityProvider]] = Field(
        default_factory=list,
        description=(
            "The capacity provider strategy to use when running the task. "
            "If a capacity provider strategy is specified, the selected launch"
            " type will be ignored."
        ),
    )
    image: Optional[str] = Field(
        default=None,
        description=(
            "The image to use for the Prefect container in the task. If this value is "
            "not null, it will override the value in the task definition. This value "
            "defaults to a Prefect base image matching your local versions."
        ),
    )
    cpu: int = Field(
        title="CPU",
        default=None,
        description=(
            "The amount of CPU to provide to the ECS task. Valid amounts are "
            "specified in the AWS documentation. If not provided, a default value of "
            f"{ECS_DEFAULT_CPU} will be used unless present on the task definition."
        ),
    )
    memory: int = Field(
        default=None,
        description=(
            "The amount of memory to provide to the ECS task. Valid amounts are "
            "specified in the AWS documentation. If not provided, a default value of "
            f"{ECS_DEFAULT_MEMORY} will be used unless present on the task definition."
        ),
    )
    container_name: str = Field(
        default=None,
        description=(
            "The name of the container flow run orchestration will occur in. If not "
            f"specified, a default value of {ECS_DEFAULT_CONTAINER_NAME} will be used "
            "and if that is not found in the task definition the first container will "
            "be used."
        ),
    )
    task_role_arn: str = Field(
        title="Task Role ARN",
        default=None,
        description=(
            "A role to attach to the task run. This controls the permissions of the "
            "task while it is running."
        ),
    )
    execution_role_arn: str = Field(
        title="Execution Role ARN",
        default=None,
        description=(
            "An execution role to use for the task. This controls the permissions of "
            "the task when it is launching. If this value is not null, it will "
            "override the value in the task definition. An execution role must be "
            "provided to capture logs from the container."
        ),
    )
    vpc_id: Optional[str] = Field(
        title="VPC ID",
        default=None,
        description=(
            "The AWS VPC to link the task run to. This is only applicable when using "
            "the 'awsvpc' network mode for your task. FARGATE tasks require this "
            "network mode, but for EC2 tasks the default network mode is 'bridge'. "
            "If using the 'awsvpc' network mode and this field is null, your default "
            "VPC will be used. If no default VPC can be found, the task run will fail."
        ),
    )
    configure_cloudwatch_logs: bool = Field(
        default=None,
        description=(
            "If enabled, the Prefect container will be configured to send its output "
            "to the AWS CloudWatch logs service. This functionality requires an "
            "execution role with logs:CreateLogStream, logs:CreateLogGroup, and "
            "logs:PutLogEvents permissions. The default for this field is `False` "
            "unless `stream_output` is set."
        ),
    )
    cloudwatch_logs_options: Dict[str, str] = Field(
        default_factory=dict,
        description=(
            "When `configure_cloudwatch_logs` is enabled, this setting may be used to"
            " pass additional options to the CloudWatch logs configuration or override"
            " the default options. See the [AWS"
            " documentation](https://docs.aws.amazon.com/AmazonECS/latest/developerguide/using_awslogs.html#create_awslogs_logdriver_options)"  # noqa
            " for available options. "
        ),
    )
    cloudwatch_logs_prefix: Optional[str] = Field(
        default=None,
        description=(
            "When `configure_cloudwatch_logs` is enabled, this setting may be used to"
            " set a prefix for the log group. If not provided, the default prefix will"
            " be `prefect-logs_<work_pool_name>_<deployment_id>`. If"
            " `awslogs-stream-prefix` is present in `Cloudwatch logs options` this"
            " setting will be ignored."
        ),
    )
    network_configuration: Dict[str, Any] = Field(
        default_factory=dict,
        description=(
            "When `network_configuration` is supplied it will override ECS Worker's"
            "awsvpcConfiguration that defined in the ECS task executing your workload. "
            "See the [AWS documentation](https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-properties-ecs-service-awsvpcconfiguration.html)"  # noqa
            " for available options."
        ),
    )
    stream_output: bool = Field(
        default=None,
        description=(
            "If enabled, logs will be streamed from the Prefect container to the local "
            "console. Unless you have configured AWS CloudWatch logs manually on your "
            "task definition, this requires the same prerequisites outlined in "
            "`configure_cloudwatch_logs`."
        ),
    )
    task_start_timeout_seconds: int = Field(
        default=300,
        description=(
            "The amount of time to watch for the start of the ECS task "
            "before marking it as failed. The task must enter a RUNNING state to be "
            "considered started."
        ),
    )
    task_watch_poll_interval: float = Field(
        default=5.0,
        description=(
            "The amount of time to wait between AWS API calls while monitoring the "
            "state of an ECS task."
        ),
    )
    auto_deregister_task_definition: bool = Field(
        default=False,
        description=(
            "If enabled, any task definitions that are created by this block will be "
            "deregistered. Existing task definitions linked by ARN will never be "
            "deregistered. Deregistering a task definition does not remove it from "
            "your AWS account, instead it will be marked as INACTIVE."
        ),
    )
    match_latest_revision_in_family: bool = Field(
        default=False,
        description=(
            "If enabled, the most recent active revision in the task definition "
            "family will be compared against the desired ECS task configuration. "
            "If they are equal, the existing task definition will be used instead "
            "of registering a new one. If no family is specified the default family "
            f'"{ECS_DEFAULT_FAMILY}" will be used.'
        ),
    )
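As an illustration of how these variables fit together, the dictionary below sketches a set of overrides for a Fargate run with log streaming. All values, including the ARN and image name, are hypothetical placeholders; how you supply such overrides (for example as deployment job variables or as work pool defaults) depends on your Prefect version and is an assumption here.

```python
# Variable names taken from the ECSVariables fields used in this example
KNOWN_VARIABLES = {
    "launch_type",
    "cpu",
    "memory",
    "image",
    "execution_role_arn",
    "stream_output",
    "env",
}

# Hypothetical overrides; every value is a placeholder
job_variables = {
    "launch_type": "FARGATE",
    "cpu": 512,
    "memory": 1024,
    "image": "example.registry.local/my-flow:latest",
    # Required because stream_output enables CloudWatch logging by default
    "execution_role_arn": "arn:aws:iam::123456789012:role/example-execution-role",
    "stream_output": True,
    # Per the env field description, a None value unsets a variable
    "env": {"MY_SETTING": "value", "UNWANTED_VAR": None},
}
```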
class ECSWorker(BaseWorker):
    """
    A Prefect worker to run flow runs as ECS tasks.
    """

    type = "ecs"
    job_configuration = ECSJobConfiguration
    job_configuration_variables = ECSVariables
    _description = (
        "Execute flow runs within containers on AWS ECS. Works with EC2 "
        "and Fargate clusters. Requires an AWS account."
    )
    _display_name = "AWS Elastic Container Service"
    _documentation_url = "https://prefecthq.github.io/prefect-aws/ecs_worker/"
    _logo_url = "https://cdn.sanity.io/images/3ugk85nk/production/d74b16fe84ce626345adf235a47008fea2869a60-225x225.png"  # noqa

    async def run(
        self,
        flow_run: "FlowRun",
        configuration: ECSJobConfiguration,
        task_status: Optional[anyio.abc.TaskStatus] = None,
    ) -> ECSWorkerResult:
        """
        Runs a given flow run on the current worker.
        """
        ecs_client = await run_sync_in_worker_thread(
            self._get_client, configuration, "ecs"
        )

        logger = self.get_flow_run_logger(flow_run)

        (
            task_arn,
            cluster_arn,
            task_definition,
            is_new_task_definition,
        ) = await run_sync_in_worker_thread(
            self._create_task_and_wait_for_start,
            logger,
            ecs_client,
            configuration,
            flow_run,
        )

        # The task identifier is "{cluster}::{task}" where we use the configured cluster
        # if set to preserve matching by name rather than arn
        # Note "::" is used despite the Prefect standard being ":" because ARNs contain
        # single colons.
        identifier = (
            (configuration.cluster if configuration.cluster else cluster_arn)
            + "::"
            + task_arn
        )

        if task_status:
            task_status.started(identifier)

        status_code = await run_sync_in_worker_thread(
            self._watch_task_and_get_exit_code,
            logger,
            configuration,
            task_arn,
            cluster_arn,
            task_definition,
            is_new_task_definition and configuration.auto_deregister_task_definition,
            ecs_client,
        )

        return ECSWorkerResult(
            identifier=identifier,
            # If the container does not start the exit code can be null but we must
            # still report a status code. We use a -1 to indicate a special code.
            status_code=status_code if status_code is not None else -1,
        )

    def _get_client(
        self, configuration: ECSJobConfiguration, client_type: Union[str, ClientType]
    ) -> _ECSClient:
        """
        Get a boto3 client of client_type. Will use a cached client if one exists.
        """
        return configuration.aws_credentials.get_client(client_type)

    def _create_task_and_wait_for_start(
        self,
        logger: logging.Logger,
        ecs_client: _ECSClient,
        configuration: ECSJobConfiguration,
        flow_run: FlowRun,
    ) -> Tuple[str, str, dict, bool]:
        """
        Register the task definition, create the task run, and wait for it to start.

        Returns a tuple of
        - The task ARN
        - The task's cluster ARN
        - The task definition
        - A bool indicating if the task definition is newly registered
        """
        task_definition_arn = configuration.task_run_request.get("taskDefinition")
        new_task_definition_registered = False

        if not task_definition_arn:
            task_definition = self._prepare_task_definition(
                configuration, region=ecs_client.meta.region_name, flow_run=flow_run
            )
            (
                task_definition_arn,
                new_task_definition_registered,
            ) = self._get_or_register_task_definition(
                logger, ecs_client, configuration, flow_run, task_definition
            )
        else:
            task_definition = self._retrieve_task_definition(
                logger, ecs_client, task_definition_arn
            )
            if configuration.task_definition:
                logger.warning(
                    "Ignoring task definition in configuration since task definition"
                    " ARN is provided on the task run request."
                )

        self._validate_task_definition(task_definition, configuration)

        _TASK_DEFINITION_CACHE[flow_run.deployment_id] = task_definition_arn

        logger.info(f"Using ECS task definition {task_definition_arn!r}...")
        logger.debug(
            f"Task definition {json.dumps(task_definition, indent=2, default=str)}"
        )

        task_run_request = self._prepare_task_run_request(
            configuration,
            task_definition,
            task_definition_arn,
        )

        logger.info("Creating ECS task run...")
        logger.debug(
            "Task run request"
            f"{json.dumps(mask_api_key(task_run_request), indent=2, default=str)}"
        )

        try:
            task = self._create_task_run(ecs_client, task_run_request)
            task_arn = task["taskArn"]
            cluster_arn = task["clusterArn"]
        except Exception as exc:
            self._report_task_run_creation_failure(configuration, task_run_request, exc)
            raise

        logger.info("Waiting for ECS task run to start...")
        self._wait_for_task_start(
            logger,
            configuration,
            task_arn,
            cluster_arn,
            ecs_client,
            timeout=configuration.task_start_timeout_seconds,
        )

        return task_arn, cluster_arn, task_definition, new_task_definition_registered

    def _get_or_register_task_definition(
        self,
        logger: logging.Logger,
        ecs_client: _ECSClient,
        configuration: ECSJobConfiguration,
        flow_run: FlowRun,
        task_definition: dict,
    ) -> Tuple[str, bool]:
        """Get or register a task definition for the given flow run.

        Returns a tuple of the task definition ARN and a bool indicating if the task
        definition is newly registered.
        """

        cached_task_definition_arn = _TASK_DEFINITION_CACHE.get(flow_run.deployment_id)
        new_task_definition_registered = False

        if cached_task_definition_arn:
            try:
                cached_task_definition = self._retrieve_task_definition(
                    logger, ecs_client, cached_task_definition_arn
                )
                if not cached_task_definition[
                    "status"
                ] == "ACTIVE" or not self._task_definitions_equal(
                    task_definition, cached_task_definition
                ):
                    cached_task_definition_arn = None
            except Exception:
                cached_task_definition_arn = None

        if (
            not cached_task_definition_arn
            and configuration.match_latest_revision_in_family
        ):
            family_name = task_definition.get("family", ECS_DEFAULT_FAMILY)
            try:
                task_definition_from_family = self._retrieve_task_definition(
                    logger, ecs_client, family_name
                )
                if task_definition_from_family and self._task_definitions_equal(
                    task_definition, task_definition_from_family
                ):
                    cached_task_definition_arn = task_definition_from_family[
                        "taskDefinitionArn"
                    ]
            except Exception:
                cached_task_definition_arn = None

        if not cached_task_definition_arn:
            task_definition_arn = self._register_task_definition(
                logger, ecs_client, task_definition
            )
            new_task_definition_registered = True
        else:
            task_definition_arn = cached_task_definition_arn

        return task_definition_arn, new_task_definition_registered

    def _watch_task_and_get_exit_code(
        self,
        logger: logging.Logger,
        configuration: ECSJobConfiguration,
        task_arn: str,
        cluster_arn: str,
        task_definition: dict,
        deregister_task_definition: bool,
        ecs_client: _ECSClient,
    ) -> Optional[int]:
        """
        Wait for the task run to complete and retrieve the exit code of the Prefect
        container.
        """

        # Wait for completion and stream logs
        task = self._wait_for_task_finish(
            logger,
            configuration,
            task_arn,
            cluster_arn,
            task_definition,
            ecs_client,
        )

        if deregister_task_definition:
            ecs_client.deregister_task_definition(
                taskDefinition=task["taskDefinitionArn"]
            )

        container_name = (
            configuration.container_name
            or _container_name_from_task_definition(task_definition)
            or ECS_DEFAULT_CONTAINER_NAME
        )

        # Check the status code of the Prefect container
        container = _get_container(task["containers"], container_name)
        assert (
            container is not None
        ), f"'{container_name}' container missing from task: {task}"
        status_code = container.get("exitCode")
        self._report_container_status_code(logger, container_name, status_code)

        return status_code

    def _report_container_status_code(
        self, logger: logging.Logger, name: str, status_code: Optional[int]
    ) -> None:
        """
        Display a log for the given container status code.
        """
        if status_code is None:
            logger.error(
                f"Task exited without reporting an exit status for container {name!r}."
            )
        elif status_code == 0:
            logger.info(f"Container {name!r} exited successfully.")
        else:
            logger.warning(
                f"Container {name!r} exited with non-zero exit code {status_code}."
            )

    def _report_task_run_creation_failure(
        self, configuration: ECSJobConfiguration, task_run: dict, exc: Exception
    ) -> None:
        """
        Wrap common AWS task run creation failures with nicer user-facing messages.
        """
        # AWS generates exception types at runtime so they must be captured a bit
        # differently than normal.
        if "ClusterNotFoundException" in str(exc):
            cluster = task_run.get("cluster", "default")
            raise RuntimeError(
                f"Failed to run ECS task, cluster {cluster!r} not found. "
                "Confirm that the cluster is configured in your region."
            ) from exc
        elif (
            "No Container Instances" in str(exc)
            and task_run.get("launchType") == "EC2"
        ):
            cluster = task_run.get("cluster", "default")
            raise RuntimeError(
                f"Failed to run ECS task, cluster {cluster!r} does not appear to "
                "have any container instances associated with it. Confirm that you "
                "have EC2 container instances available."
            ) from exc
        elif (
            "failed to validate logger args" in str(exc)
            and "AccessDeniedException" in str(exc)
            and configuration.configure_cloudwatch_logs
        ):
            raise RuntimeError(
                "Failed to run ECS task, the attached execution role does not appear"
                " to have sufficient permissions. Ensure that the execution role"
                f" {configuration.execution_role!r} has permissions"
                " logs:CreateLogStream, logs:CreateLogGroup, and logs:PutLogEvents."
            )
        else:
            raise

    def _validate_task_definition(
        self, task_definition: dict, configuration: ECSJobConfiguration
    ) -> None:
        """
        Ensure that the task definition is compatible with the configuration.

        Raises `ValueError` on incompatibility. Returns `None` on success.
        """
        launch_type = configuration.task_run_request.get(
            "launchType", ECS_DEFAULT_LAUNCH_TYPE
        )
        if (
            launch_type != "EC2"
            and "FARGATE" not in task_definition["requiresCompatibilities"]
        ):
            raise ValueError(
                "Task definition does not have 'FARGATE' in 'requiresCompatibilities'"
                f" and cannot be used with launch type {launch_type!r}"
            )

        if launch_type == "FARGATE" or launch_type == "FARGATE_SPOT":
            # Only the 'awsvpc' network mode is supported when using FARGATE
            network_mode = task_definition.get("networkMode")
            if network_mode != "awsvpc":
                raise ValueError(
                    f"Found network mode {network_mode!r} which is not compatible with "
                    f"launch type {launch_type!r}. Use either the 'EC2' launch "
                    "type or the 'awsvpc' network mode."
                )

        if configuration.configure_cloudwatch_logs and not task_definition.get(
            "executionRoleArn"
        ):
            raise ValueError(
                "An execution role arn must be set on the task definition to use "
                "`configure_cloudwatch_logs` or `stream_logs` but no execution role "
                "was found on the task definition."
            )

    def _register_task_definition(
        self,
        logger: logging.Logger,
        ecs_client: _ECSClient,
        task_definition: dict,
    ) -> str:
        """
        Register a new task definition with AWS.

        Returns the ARN.
        """
        logger.info("Registering ECS task definition...")
        logger.debug(
            "Task definition request"
            f"{json.dumps(task_definition, indent=2, default=str)}"
        )
        response = ecs_client.register_task_definition(**task_definition)
        return response["taskDefinition"]["taskDefinitionArn"]

    def _retrieve_task_definition(
        self,
        logger: logging.Logger,
        ecs_client: _ECSClient,
        task_definition: str,
    ):
        """
        Retrieve an existing task definition from AWS.
        """
        if task_definition.startswith("arn:aws:ecs:"):
            logger.info(f"Retrieving ECS task definition {task_definition!r}...")
        else:
            logger.info(
                "Retrieving most recent active revision from "
                f"ECS task family {task_definition!r}..."
            )
        response = ecs_client.describe_task_definition(taskDefinition=task_definition)
        return response["taskDefinition"]

    def _wait_for_task_start(
        self,
        logger: logging.Logger,
        configuration: ECSJobConfiguration,
        task_arn: str,
        cluster_arn: str,
        ecs_client: _ECSClient,
        timeout: int,
    ) -> dict:
        """
        Waits for an ECS task run to reach a RUNNING status.

        If a STOPPED status is reached instead, an exception is raised indicating the
        reason that the task run did not start.
        """
        for task in self._watch_task_run(
            logger,
            configuration,
            task_arn,
            cluster_arn,
            ecs_client,
            until_status="RUNNING",
            timeout=timeout,
        ):
            # TODO: It is possible that the task has passed _through_ a RUNNING
            #       status during the polling interval. In this case, there is not an
            #       exception to raise.
            if task["lastStatus"] == "STOPPED":
                code = task.get("stopCode")
                reason = task.get("stoppedReason")
                # Generate a dynamic exception type from the AWS name
                raise type(code, (RuntimeError,), {})(reason)

        return task

    def _wait_for_task_finish(
        self,
        logger: logging.Logger,
        configuration: ECSJobConfiguration,
        task_arn: str,
        cluster_arn: str,
        task_definition: dict,
        ecs_client: _ECSClient,
    ):
        """
        Watch an ECS task until it reaches a STOPPED status.

        If configured, logs from the Prefect container are streamed to stderr.

        Returns a description of the task on completion.
        """
        can_stream_output = False
        container_name = (
            configuration.container_name
            or _container_name_from_task_definition(task_definition)
            or ECS_DEFAULT_CONTAINER_NAME
        )

        if configuration.stream_output:
            container_def = _get_container(
                task_definition["containerDefinitions"], container_name
            )
            if not container_def:
                logger.warning(
                    "Prefect container definition not found in "
                    "task definition. Output cannot be streamed."
                )
            elif not container_def.get("logConfiguration"):
                logger.warning(
                    "Logging configuration not found on task. "
                    "Output cannot be streamed."
                )
            elif not container_def["logConfiguration"].get("logDriver") == "awslogs":
                logger.warning(
                    "Logging configuration uses unsupported "
                    " driver {container_def['logConfiguration'].get('logDriver')!r}. "
                    "Output cannot be streamed."
                )
            else:
                # Prepare to stream the output
                log_config = container_def["logConfiguration"]["options"]
                logs_client = self._get_client(configuration, "logs")
                can_stream_output = True
                # Track the last log timestamp to prevent double display
                last_log_timestamp: Optional[int] = None
                # Determine the name of the stream as "prefix/container/run-id"
                stream_name = "/".join(
                    [
                        log_config["awslogs-stream-prefix"],
                        container_name,
                        task_arn.rsplit("/")[-1],
                    ]
                )
                self._logger.info(
                    f"Streaming output from container {container_name!r}..."
                )

        for task in self._watch_task_run(
            logger,
            configuration,
            task_arn,
            cluster_arn,
            ecs_client,
            current_status="RUNNING",
        ):
            if configuration.stream_output and can_stream_output:
                # On each poll for task run status, also retrieve available logs
                last_log_timestamp = self._stream_available_logs(
                    logger,
                    logs_client,
                    log_group=log_config["awslogs-group"],
                    log_stream=stream_name,
                    last_log_timestamp=last_log_timestamp,
                )

        return task

    def _stream_available_logs(
        self,
        logger: logging.Logger,
        logs_client: Any,
        log_group: str,
        log_stream: str,
        last_log_timestamp: Optional[int] = None,
    ) -> Optional[int]:
        """
        Stream logs from the given log group and stream since the last log timestamp.

        Will continue on paginated responses until all logs are returned.

        Returns the last log timestamp which can be used to call this method in the
        future.
        """
        last_log_stream_token = "NO-TOKEN"
        next_log_stream_token = None

        # AWS will return the same token that we send once the end of the paginated
        # response is reached
        while last_log_stream_token != next_log_stream_token:
            last_log_stream_token = next_log_stream_token

            request = {
                "logGroupName": log_group,
                "logStreamName": log_stream,
            }

            if last_log_stream_token is not None:
                request["nextToken"] = last_log_stream_token

            if last_log_timestamp is not None:
                # Bump the timestamp by one ms to avoid retrieving the last log again
                request["startTime"] = last_log_timestamp + 1

            try:
                response = logs_client.get_log_events(**request)
            except Exception:
                logger.error(
                    f"Failed to read log events with request {request}",
                    exc_info=True,
                )
                return last_log_timestamp

            log_events = response["events"]
            for log_event in log_events:
                # TODO: This doesn't forward to the local logger, which can be
                #       bad for customizing handling and understanding where the
                #       log is coming from, but it avoid nesting logger information
                #       when the content is output from a Prefect logger on the
                #       running infrastructure
                print(log_event["message"], file=sys.stderr)

                if (
                    last_log_timestamp is None
                    or log_event["timestamp"] > last_log_timestamp
                ):
                    last_log_timestamp = log_event["timestamp"]

            next_log_stream_token = response.get("nextForwardToken")
            if not log_events:
                # Stop reading pages if there was no data
                break

        return last_log_timestamp

    def _watch_task_run(
        self,
        logger: logging.Logger,
        configuration: ECSJobConfiguration,
        task_arn: str,
        cluster_arn: str,
        ecs_client: _ECSClient,
        current_status: str = "UNKNOWN",
        until_status: str = None,
        timeout: int = None,
    ) -> Generator[None, None, dict]:
        """
        Watches an ECS task run by querying every `poll_interval` seconds. After each
        query, the retrieved task is yielded. This function returns when the task run
        reaches a STOPPED status or the provided `until_status`.

        Emits a log each time the status changes.
        """
        last_status = status = current_status
        t0 = time.time()
        while status != until_status:
            tasks = ecs_client.describe_tasks(
                tasks=[task_arn], cluster=cluster_arn, include=["TAGS"]
            )["tasks"]

            if tasks:
                task = tasks[0]

                status = task["lastStatus"]
                if status != last_status:
                    logger.info(f"ECS task status is {status}.")

                yield task

                # No point in continuing if the status is final
                if status == "STOPPED":
                    break

                last_status = status

            else:
                # Intermittently, the task will not be described. We wat to respect the
                # watch timeout though.
                logger.debug("Task not found.")

            elapsed_time = time.time() - t0
            if timeout is not None and elapsed_time > timeout:
                raise RuntimeError(
                    f"Timed out after {elapsed_time}s while watching task for status "
                    f"{until_status or 'STOPPED'}."
                )
            time.sleep(configuration.task_watch_poll_interval)

    def _get_or_generate_family(self, task_definition: dict, flow_run: FlowRun) -> str:
        """
        Gets or generate a family for the task definition.
        """
        family = task_definition.get("family")
        if not family:
            assert self._work_pool_name and flow_run.deployment_id
            family = (
                f"{ECS_DEFAULT_FAMILY}_{self._work_pool_name}_{flow_run.deployment_id}"
            )
        slugify(
            family,
            max_length=255,
            regex_pattern=r"[^a-zA-Z0-9-_]+",
        )
        return family

    def _prepare_task_definition(
        self,
        configuration: ECSJobConfiguration,
        region: str,
        flow_run: FlowRun,
    ) -> dict:
        """
        Prepare a task definition by inferring any defaults and merging overrides.
        """
        task_definition = copy.deepcopy(configuration.task_definition)

        # Configure the Prefect runtime container
        task_definition.setdefault("containerDefinitions", [])

        # Remove empty container definitions
        task_definition["containerDefinitions"] = [
            d for d in task_definition["containerDefinitions"] if d
        ]

        container_name = configuration.container_name
        if not container_name:
            container_name = (
                _container_name_from_task_definition(task_definition)
                or ECS_DEFAULT_CONTAINER_NAME
            )

        container = _get_container(
            task_definition["containerDefinitions"], container_name
        )
        if container is None:
            if container_name != ECS_DEFAULT_CONTAINER_NAME:
                raise ValueError(
                    f"Container {container_name!r} not found in task definition."
                )

            # Look for a container without a name
            for container in task_definition["containerDefinitions"]:
                if "name" not in container:
                    container["name"] = container_name
                    break
            else:
                container = {"name": container_name}
                task_definition["containerDefinitions"].append(container)

        # Image is required so make sure it's present
        container.setdefault("image", get_prefect_image_name())

        # Remove any keys that have been explicitly "unset"
        unset_keys = {key for key, value in configuration.env.items() if value is None}
        for item in tuple(container.get("environment", [])):
            if item["name"] in unset_keys or item["value"] is None:
                container["environment"].remove(item)

        if configuration.configure_cloudwatch_logs:
            prefix = f"prefect-logs_{self._work_pool_name}_{flow_run.deployment_id}"
            container["logConfiguration"] = {
                "logDriver": "awslogs",
                "options": {
                    "awslogs-create-group": "true",
                    "awslogs-group": "prefect",
                    "awslogs-region": region,
                    "awslogs-stream-prefix": (
                        configuration.cloudwatch_logs_prefix or prefix
                    ),
                    **configuration.cloudwatch_logs_options,
                },
            }

        task_definition["family"] = self._get_or_generate_family(
            task_definition, flow_run
        )
        # CPU and memory are required in some cases, retrieve the value to
usecpu=task_definition.get("cpu")orECS_DEFAULT_CPUmemory=task_definition.get("memory")orECS_DEFAULT_MEMORYlaunch_type=configuration.task_run_request.get("launchType",ECS_DEFAULT_LAUNCH_TYPE)iflaunch_type=="FARGATE"orlaunch_type=="FARGATE_SPOT":# Task level memory and cpu are required when using fargatetask_definition["cpu"]=str(cpu)task_definition["memory"]=str(memory)# The FARGATE compatibility is required if it will be used as as launch typerequires_compatibilities=task_definition.setdefault("requiresCompatibilities",[])if"FARGATE"notinrequires_compatibilities:task_definition["requiresCompatibilities"].append("FARGATE")# Only the 'awsvpc' network mode is supported when using FARGATE# However, we will not enforce that here if the user has set ittask_definition.setdefault("networkMode","awsvpc")eliflaunch_type=="EC2":# Container level memory and cpu are required when using ec2container.setdefault("cpu",cpu)container.setdefault("memory",memory)# Ensure set values are cast to integerscontainer["cpu"]=int(container["cpu"])container["memory"]=int(container["memory"])# Ensure set values are cast to stringsiftask_definition.get("cpu"):task_definition["cpu"]=str(task_definition["cpu"])iftask_definition.get("memory"):task_definition["memory"]=str(task_definition["memory"])returntask_definitiondef_load_network_configuration(self,vpc_id:Optional[str],configuration:ECSJobConfiguration)->dict:""" Load settings from a specific VPC or the default VPC and generate a task run request's network configuration. 
"""ec2_client=self._get_client(configuration,"ec2")vpc_message="the default VPC"ifnotvpc_idelsef"VPC with ID {vpc_id}"ifnotvpc_id:# Retrieve the default VPCdescribe={"Filters":[{"Name":"isDefault","Values":["true"]}]}else:describe={"VpcIds":[vpc_id]}vpcs=ec2_client.describe_vpcs(**describe)["Vpcs"]ifnotvpcs:help_message=("Pass an explicit `vpc_id` or configure a default VPC."ifnotvpc_idelse"Check that the VPC exists in the current region.")raiseValueError(f"Failed to find {vpc_message}. ""Network configuration cannot be inferred. "+help_message)vpc_id=vpcs[0]["VpcId"]subnets=ec2_client.describe_subnets(Filters=[{"Name":"vpc-id","Values":[vpc_id]}])["Subnets"]ifnotsubnets:raiseValueError(f"Failed to find subnets for {vpc_message}. ""Network configuration cannot be inferred.")return{"awsvpcConfiguration":{"subnets":[s["SubnetId"]forsinsubnets],"assignPublicIp":"ENABLED","securityGroups":[],}}def_custom_network_configuration(self,vpc_id:str,network_configuration:dict,configuration:ECSJobConfiguration,)->dict:""" Load settings from a specific VPC or the default VPC and generate a task run request's network configuration. """ec2_client=self._get_client(configuration,"ec2")vpc_message=f"VPC with ID {vpc_id}"vpcs=ec2_client.describe_vpcs(VpcIds=[vpc_id]).get("Vpcs")ifnotvpcs:raiseValueError(f"Failed to find {vpc_message}. "+"Network configuration cannot be inferred. "+"Pass an explicit `vpc_id`.")vpc_id=vpcs[0]["VpcId"]subnets=ec2_client.describe_subnets(Filters=[{"Name":"vpc-id","Values":[vpc_id]}])["Subnets"]ifnotsubnets:raiseValueError(f"Failed to find subnets for {vpc_message}. 
"+"Network configuration cannot be inferred.")subnet_ids=[subnet["SubnetId"]forsubnetinsubnets]config_subnets=network_configuration.get("subnets",[])ifnotall(conf_sninsubnet_idsforconf_sninconfig_subnets):raiseValueError(f"Subnets {config_subnets} not found within {vpc_message}."+"Please check that VPC is associated with supplied subnets.")return{"awsvpcConfiguration":network_configuration}def_prepare_task_run_request(self,configuration:ECSJobConfiguration,task_definition:dict,task_definition_arn:str,)->dict:""" Prepare a task run request payload. """task_run_request=deepcopy(configuration.task_run_request)task_run_request.setdefault("taskDefinition",task_definition_arn)asserttask_run_request["taskDefinition"]==task_definition_arncapacityProviderStrategy=task_run_request.get("capacityProviderStrategy")ifcapacityProviderStrategy:# Should not be provided at all if capacityProviderStrategy is set, see https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_RunTask.html#ECS-RunTask-request-capacityProviderStrategy # noqaself._logger.warning("Found capacityProviderStrategy. 
""Removing launchType from task run request.")task_run_request.pop("launchType",None)eliftask_run_request.get("launchType")=="FARGATE_SPOT":# Should not be provided at all for FARGATE SPOTtask_run_request.pop("launchType",None)# A capacity provider strategy is required for FARGATE SPOTtask_run_request["capacityProviderStrategy"]=[{"capacityProvider":"FARGATE_SPOT","weight":1}]overrides=task_run_request.get("overrides",{})container_overrides=overrides.get("containerOverrides",[])# Ensure the network configuration is present if using awsvpc for network modeif(task_definition.get("networkMode")=="awsvpc"andnottask_run_request.get("networkConfiguration")andnotconfiguration.network_configuration):task_run_request["networkConfiguration"]=self._load_network_configuration(configuration.vpc_id,configuration)# Use networkConfiguration if supplied by userif(task_definition.get("networkMode")=="awsvpc"andconfiguration.network_configurationandconfiguration.vpc_id):task_run_request["networkConfiguration"]=self._custom_network_configuration(configuration.vpc_id,configuration.network_configuration,configuration,)# Ensure the container name is set if not provided at template timecontainer_name=(configuration.container_nameor_container_name_from_task_definition(task_definition)orECS_DEFAULT_CONTAINER_NAME)ifcontainer_overridesandnotcontainer_overrides[0].get("name"):container_overrides[0]["name"]=container_name# Ensure configuration command is respected post-templatingorchestration_container=_get_container(container_overrides,container_name)iforchestration_container:# Override the command if given on the configurationifconfiguration.command:orchestration_container["command"]=configuration.command# Clean up templated variable 
formattingforcontainerincontainer_overrides:ifisinstance(container.get("command"),str):container["command"]=shlex.split(container["command"])ifisinstance(container.get("environment"),dict):container["environment"]=[{"name":k,"value":v}fork,vincontainer["environment"].items()]# Remove null values — they're not allowed by AWScontainer["environment"]=[itemforitemincontainer.get("environment",[])ifitem["value"]isnotNone]ifisinstance(task_run_request.get("tags"),dict):task_run_request["tags"]=[{"key":k,"value":v}fork,vintask_run_request["tags"].items()]ifoverrides.get("cpu"):overrides["cpu"]=str(overrides["cpu"])ifoverrides.get("memory"):overrides["memory"]=str(overrides["memory"])# Ensure configuration tags and env are respected post-templatingtags=[itemforitemintask_run_request.get("tags",[])ifitem["key"]notinconfiguration.labels.keys()]+[{"key":k,"value":v}fork,vinconfiguration.labels.items()ifvisnotNone]# Slugify tags keys and valuestags=[{"key":slugify(item["key"],regex_pattern=_TAG_REGEX,allow_unicode=True,lowercase=False,),"value":slugify(item["value"],regex_pattern=_TAG_REGEX,allow_unicode=True,lowercase=False,),}foritemintags]iftags:task_run_request["tags"]=tagsiforchestration_container:environment=[itemforiteminorchestration_container.get("environment",[])ifitem["name"]notinconfiguration.env.keys()]+[{"name":k,"value":v}fork,vinconfiguration.env.items()ifvisnotNone]ifenvironment:orchestration_container["environment"]=environment# Remove empty container overridesoverrides["containerOverrides"]=[vforvincontainer_overridesifv]returntask_run_request@retry(stop=stop_after_attempt(MAX_CREATE_TASK_RUN_ATTEMPTS),wait=wait_fixed(CREATE_TASK_RUN_MIN_DELAY_SECONDS)+wait_random(CREATE_TASK_RUN_MIN_DELAY_JITTER_SECONDS,CREATE_TASK_RUN_MAX_DELAY_JITTER_SECONDS,),reraise=True,)def_create_task_run(self,ecs_client:_ECSClient,task_run_request:dict)->str:""" Create a run of a task definition. Returns the task run ARN. 
"""task=ecs_client.run_task(**task_run_request)iftask["failures"]:raiseRuntimeError(f"Failed to run ECS task: {task['failures'][0]['reason']}")elifnottask["tasks"]:raiseRuntimeError("Failed to run ECS task: no tasks or failures were returned.")returntask["tasks"][0]def_task_definitions_equal(self,taskdef_1,taskdef_2)->bool:""" Compare two task definitions. Since one may come from the AWS API and have populated defaults, we do our best to homogenize the definitions without changing their meaning. """iftaskdef_1==taskdef_2:returnTrueiftaskdef_1isNoneortaskdef_2isNone:returnFalsetaskdef_1=copy.deepcopy(taskdef_1)taskdef_2=copy.deepcopy(taskdef_2)fortaskdefin(taskdef_1,taskdef_2):# Set defaults that AWS would set after registrationcontainer_definitions=taskdef.get("containerDefinitions",[])essential=any(container.get("essential")forcontainerincontainer_definitions)ifnotessential:container_definitions[0].setdefault("essential",True)taskdef.setdefault("networkMode","bridge")_drop_empty_keys_from_task_definition(taskdef_1)_drop_empty_keys_from_task_definition(taskdef_2)# Clear fields that change on registration for comparisonforfieldinECS_POST_REGISTRATION_FIELDS:taskdef_1.pop(field,None)taskdef_2.pop(field,None)returntaskdef_1==taskdef_2asyncdefkill_infrastructure(self,configuration:ECSJobConfiguration,infrastructure_pid:str,grace_seconds:int=30,)->None:""" Kill a task running on ECS. Args: infrastructure_pid: A cluster and task arn combination. This should match a value yielded by `ECSWorker.run`. """ifgrace_seconds!=30:self._logger.warning(f"Kill grace period of {grace_seconds}s requested, but AWS does not ""support dynamic grace period configuration so 30s will be used. 
""See https://docs.aws.amazon.com/AmazonECS/latest/developerguide/ecs-agent-config.html for configuration of grace periods."# noqa)cluster,task=parse_identifier(infrastructure_pid)awaitrun_sync_in_worker_thread(self._stop_task,configuration,cluster,task)def_stop_task(self,configuration:ECSJobConfiguration,cluster:str,task:str)->None:""" Stop a running ECS task. """ifconfiguration.clusterisnotNoneandcluster!=configuration.cluster:raiseInfrastructureNotAvailable("Cannot stop ECS task: this infrastructure block has access to "f"cluster {configuration.cluster!r} but the task is running in cluster "f"{cluster!r}.")ecs_client=self._get_client(configuration,"ecs")try:ecs_client.stop_task(cluster=cluster,task=task)exceptExceptionasexc:# Raise a special exception if the task does not existif"ClusterNotFound"instr(exc):raiseInfrastructureNotFound(f"Cannot stop ECS task: the cluster {cluster!r} could not be found.")fromexcif"not find task"instr(exc)or"referenced task was not found"instr(exc):raiseInfrastructureNotFound(f"Cannot stop ECS task: the task {task!r} could not be found in "f"cluster {cluster!r}.")fromexcif"no registered tasks"instr(exc):raiseInfrastructureNotFound(f"Cannot stop ECS task: the cluster {cluster!r} has no tasks.")fromexc# Reraise unknown exceptionsraise
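The "clean up templated variable formatting" step in `_prepare_task_run_request` above can be tried in isolation. The following is a minimal sketch, not the worker's own code: `normalize_container_override` is a hypothetical helper written for illustration, and the sample override values are made up. It mirrors the same three rules: a string command is split into a list, a mapping-style environment becomes a list of name/value entries, and entries with a `None` value are dropped.

```python
import shlex
from typing import Any, Dict


def normalize_container_override(container: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical standalone helper; not part of prefect-aws."""
    normalized = dict(container)

    # A templated command may arrive as a single string; split it into a list.
    if isinstance(normalized.get("command"), str):
        normalized["command"] = shlex.split(normalized["command"])

    # A templated environment may arrive as a mapping; convert it to a list
    # of {"name": ..., "value": ...} entries.
    environment = normalized.get("environment", [])
    if isinstance(environment, dict):
        environment = [{"name": k, "value": v} for k, v in environment.items()]

    # Drop entries whose value is None, matching the worker's null filtering.
    normalized["environment"] = [
        item for item in environment if item["value"] is not None
    ]
    return normalized


override = normalize_container_override(
    {
        "name": "prefect",
        "command": "python -m prefect.engine",
        "environment": {"EXAMPLE_KEY": "example-value", "UNSET_ME": None},
    }
)
```

Unlike the worker code, this sketch returns a shallow copy rather than mutating the override in place, which keeps the example easy to test.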
async def run(
    self,
    flow_run: "FlowRun",
    configuration: ECSJobConfiguration,
    task_status: Optional[anyio.abc.TaskStatus] = None,
) -> ECSWorkerResult:
    """
    Runs a given flow run on the current worker.
    """
    ecs_client = await run_sync_in_worker_thread(
        self._get_client, configuration, "ecs"
    )

    logger = self.get_flow_run_logger(flow_run)

    (
        task_arn,
        cluster_arn,
        task_definition,
        is_new_task_definition,
    ) = await run_sync_in_worker_thread(
        self._create_task_and_wait_for_start,
        logger,
        ecs_client,
        configuration,
        flow_run,
    )

    # The task identifier is "{cluster}::{task}" where we use the configured cluster
    # if set to preserve matching by name rather than arn
    # Note "::" is used despite the Prefect standard being ":" because ARNs contain
    # single colons.
    identifier = (
        (configuration.cluster if configuration.cluster else cluster_arn)
        + "::"
        + task_arn
    )

    if task_status:
        task_status.started(identifier)

    status_code = await run_sync_in_worker_thread(
        self._watch_task_and_get_exit_code,
        logger,
        configuration,
        task_arn,
        cluster_arn,
        task_definition,
        is_new_task_definition and configuration.auto_deregister_task_definition,
        ecs_client,
    )

    return ECSWorkerResult(
        identifier=identifier,
        # If the container does not start the exit code can be null but we must
        # still report a status code. We use a -1 to indicate a special code.
        status_code=status_code if status_code is not None else -1,
    )
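Two small rules in `run` are easy to miss: how the infrastructure identifier is assembled, and how a missing exit code is reported. The sketch below restates both as standalone functions. The function names and the sample ARNs are hypothetical and exist only for this illustration; the expressions themselves follow the source above.

```python
from typing import Optional


def build_identifier(
    configured_cluster: Optional[str], cluster_arn: str, task_arn: str
) -> str:
    # Prefer the configured cluster (matching by name) and fall back to the ARN.
    cluster = configured_cluster if configured_cluster else cluster_arn
    # "::" is the separator because ARNs already contain single colons.
    return cluster + "::" + task_arn


def reported_status_code(exit_code: Optional[int]) -> int:
    # A container that never started has no exit code; -1 stands in for it.
    return exit_code if exit_code is not None else -1


# Made-up ARNs for illustration only.
CLUSTER_ARN = "arn:aws:ecs:us-east-1:123456789012:cluster/default"
TASK_ARN = "arn:aws:ecs:us-east-1:123456789012:task/default/0123456789abcdef"

by_name = build_identifier("my-cluster", CLUSTER_ARN, TASK_ARN)
by_arn = build_identifier(None, CLUSTER_ARN, TASK_ARN)
```

Here `by_name` starts with `my-cluster::` while `by_arn` starts with the full cluster ARN. Both forms can be split back apart on the first `::`.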
Splits identifier into its cluster and task components, e.g.
input "cluster_name::task_arn" outputs ("cluster_name", "task_arn").
Source code in prefect_aws/workers/ecs_worker.py
def parse_identifier(identifier: str) -> ECSIdentifier:
    """
    Splits identifier into its cluster and task components, e.g.
    input "cluster_name::task_arn" outputs ("cluster_name", "task_arn").
    """
    cluster, task = identifier.split("::", maxsplit=1)
    return ECSIdentifier(cluster, task)
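To see why `maxsplit=1` matters, here is a self-contained sketch of the same split applied to an identifier whose task component is a full ARN. The `ECSIdentifier` definition below is an assumption made so the example runs on its own (a named tuple with assumed field names `cluster` and `task_arn`), and the ARN is made up.

```python
from typing import NamedTuple


class ECSIdentifier(NamedTuple):
    # Field names are assumed for this sketch.
    cluster: str
    task_arn: str


def parse_identifier(identifier: str) -> ECSIdentifier:
    # Split only on the first "::" so the task ARN is kept intact.
    cluster, task = identifier.split("::", maxsplit=1)
    return ECSIdentifier(cluster, task)


parsed = parse_identifier(
    "my-cluster::arn:aws:ecs:us-east-1:123456789012:task/my-cluster/0123456789abcdef"
)
```

The single colons inside the ARN are untouched because only the `::` separator is matched. An identifier with no `::` at all raises `ValueError` when the result is unpacked, so callers that accept untrusted identifiers may want to catch that.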