tidy up the ‘bulk sns publish’ script
- ID
baa04f3- date
2023-04-14 23:02:25+00:00- author
Alex Chan <alex@alexwlchan.net>- parent
7579cb2- message
tidy up the 'bulk sns publish' script- changed files
1 file, 50 additions, 35 deletions
Changed files
aws/bulk_sns_publish (2870) → aws/bulk_sns_publish (3500)
diff --git a/aws/bulk_sns_publish b/aws/bulk_sns_publish
index 446e982..eb44f9f 100755
--- a/aws/bulk_sns_publish
+++ b/aws/bulk_sns_publish
@@ -21,6 +21,7 @@ This script provides a convenient wrapper for doing so.
"""
+import functools
import os
import secrets
import sys
@@ -50,54 +51,68 @@ def get_aws_session(*, role_arn):
)
-ACCOUNT_NAMES = {
- '760097843905': 'platform',
-}
+def get_session(*, topic_arn):
+ """
+ Return a boto3 Session for publishing to SNS.
+ If it recognises the account which contains the topic, it will pick
+ the appropriate IAM role, otherwise it use the default boto3 Session.
+ """
+ account_names = {
+ "760097843905": "platform",
+ }
-@click.command()
-@click.argument("INPUT_FILE", required=True)
-@click.option("--topic-arn", required=True)
-@click.option("--parallelism", required=True, type=int)
-def main(input_file, topic_arn, parallelism):
- def inputs():
- with open(input_file) as messages:
- for i, batch in enumerate(more_itertools.chunked(messages, n=10)):
- batch = [line.rstrip() for line in batch]
+ # The arn format of an SNS topic is:
+ #
+ # arn:aws:sns:{region}:{account_id}:{topic_name}
+ #
+ # Extract the account ID.
+ account_id = topic_arn.split(":")[4]
- batch_request_entries = [
- {"Id": secrets.token_hex(), "Message": message} for message in batch
- ]
+ try:
+ role_arn = (
+ f"arn:aws:iam::{account_id}:role/{account_names[account_id]}-developer"
+ )
+ return get_aws_session(role_arn=role_arn)
+ except KeyError:
+ return boto3.Session()
- yield batch_request_entries
- # choose the appropriate topic_arn here
+def get_batch_entries(path):
+ """
+ Given a file which contains one notification per line, generate a series
+ of values that can be passed as the `PublishBatchRequestEntries` argument
+ to the `Sns.publish_batch` method.
+ """
+ for batch in more_itertools.chunked(open(path), n=10):
+ yield [{"Id": secrets.token_hex(), "Message": line.strip()} for line in batch]
- account_id = topic_arn.split(':')[4]
- try:
- role_arn = f'arn:aws:iam::{account_id}:role/{ACCOUNT_NAMES[account_id]}-developer'
- print(f'Assuming role {role_arn}...')
- sess = get_aws_session(
- role_arn=role_arn
- )
- except KeyError:
- sess = boto3.Session()
+@click.command()
+@click.argument("INPUT_FILE", required=True)
+@click.option("--topic-arn", required=True)
+@click.option("--parallelism", default=5, type=int)
+def main(input_file, topic_arn, parallelism):
+ sess = get_session(topic_arn=topic_arn)
+ # Note: creating boto3 clients isn't thread-safe, so it's important
+ # to create it once rather than creating it multiple times in the
+ # concurrently() handler.
+ #
+ # See https://github.com/boto/boto3/issues/801
sns_client = sess.client("sns")
- def publish(batch_request_entries):
- sns_client.publish_batch(
- TopicArn=topic_arn, PublishBatchRequestEntries=batch_request_entries
- )
-
- total_entries = sum(len(sns_in) for sns_in in inputs())
+ total_entries = sum(len(entries) for entries in get_batch_entries(input_file))
with tqdm.tqdm(total=total_entries) as pbar:
- for (sns_in, sns_out) in concurrently(
- publish, inputs(), max_concurrency=parallelism
+ for (batch, _) in concurrently(
+ handler=lambda batch_entries: sns_client.publish_batch(
+ TopicArn=topic_arn, PublishBatchRequestEntries=batch_entries
+ ),
+ inputs=get_batch_entries(input_file),
+ max_concurrency=parallelism,
):
- pbar.update(len(sns_in))
+ pbar.update(len(batch))
if __name__ == "__main__":