Skip to main content

tidy up the ‘bulk sns publish’ script

ID
baa04f3
date
2023-04-14 23:02:25+00:00
author
Alex Chan <alex@alexwlchan.net>
parent
7579cb2
message
tidy up the 'bulk sns publish' script
changed files
1 file, 50 additions, 35 deletions

Changed files

aws/bulk_sns_publish (2870) → aws/bulk_sns_publish (3500)

diff --git a/aws/bulk_sns_publish b/aws/bulk_sns_publish
index 446e982..eb44f9f 100755
--- a/aws/bulk_sns_publish
+++ b/aws/bulk_sns_publish
@@ -21,6 +21,7 @@ This script provides a convenient wrapper for doing so.
 
 """
 
+import functools
 import os
 import secrets
 import sys
@@ -50,54 +51,68 @@ def get_aws_session(*, role_arn):
     )
 
 
-ACCOUNT_NAMES = {
-    '760097843905': 'platform',
-}
+def get_session(*, topic_arn):
+    """
+    Return a boto3 Session for publishing to SNS.
 
+    If it recognises the account which contains the topic, it will pick
+    the appropriate IAM role, otherwise it uses the default boto3 Session.
+    """
+    account_names = {
+        "760097843905": "platform",
+    }
 
-@click.command()
-@click.argument("INPUT_FILE", required=True)
-@click.option("--topic-arn", required=True)
-@click.option("--parallelism", required=True, type=int)
-def main(input_file, topic_arn, parallelism):
-    def inputs():
-        with open(input_file) as messages:
-            for i, batch in enumerate(more_itertools.chunked(messages, n=10)):
-                batch = [line.rstrip() for line in batch]
+    # The arn format of an SNS topic is:
+    #
+    #       arn:aws:sns:{region}:{account_id}:{topic_name}
+    #
+    # Extract the account ID.
+    account_id = topic_arn.split(":")[4]
 
-                batch_request_entries = [
-                    {"Id": secrets.token_hex(), "Message": message} for message in batch
-                ]
+    try:
+        role_arn = (
+            f"arn:aws:iam::{account_id}:role/{account_names[account_id]}-developer"
+        )
+        return get_aws_session(role_arn=role_arn)
+    except KeyError:
+        return boto3.Session()
 
-                yield batch_request_entries
 
-    # choose the appropriate topic_arn here
+def get_batch_entries(path):
+    """
+    Given a file which contains one notification per line, generate a series
+    of values that can be passed as the `PublishBatchRequestEntries` argument
+    to the `Sns.publish_batch` method.
+    """
+    for batch in more_itertools.chunked(open(path), n=10):
+        yield [{"Id": secrets.token_hex(), "Message": line.strip()} for line in batch]
 
-    account_id = topic_arn.split(':')[4]
 
-    try:
-        role_arn = f'arn:aws:iam::{account_id}:role/{ACCOUNT_NAMES[account_id]}-developer'
-        print(f'Assuming role {role_arn}...')
-        sess = get_aws_session(
-            role_arn=role_arn
-        )
-    except KeyError:
-        sess = boto3.Session()
+@click.command()
+@click.argument("INPUT_FILE", required=True)
+@click.option("--topic-arn", required=True)
+@click.option("--parallelism", default=5, type=int)
+def main(input_file, topic_arn, parallelism):
+    sess = get_session(topic_arn=topic_arn)
 
+    # Note: creating boto3 clients isn't thread-safe, so it's important
+    # to create the client once and share it, rather than creating a new
+    # one on every call inside the concurrently() handler.
+    #
+    # See https://github.com/boto/boto3/issues/801
     sns_client = sess.client("sns")
 
-    def publish(batch_request_entries):
-        sns_client.publish_batch(
-            TopicArn=topic_arn, PublishBatchRequestEntries=batch_request_entries
-        )
-
-    total_entries = sum(len(sns_in) for sns_in in inputs())
+    total_entries = sum(len(entries) for entries in get_batch_entries(input_file))
 
     with tqdm.tqdm(total=total_entries) as pbar:
-        for (sns_in, sns_out) in concurrently(
-            publish, inputs(), max_concurrency=parallelism
+        for (batch, _) in concurrently(
+            handler=lambda batch_entries: sns_client.publish_batch(
+                TopicArn=topic_arn, PublishBatchRequestEntries=batch_entries
+            ),
+            inputs=get_batch_entries(input_file),
+            max_concurrency=parallelism,
         ):
-            pbar.update(len(sns_in))
+            pbar.update(len(batch))
 
 
 if __name__ == "__main__":