Skip to main content

handle the case where sns fails

ID
7cfc8de
date
2023-04-16 15:00:31+00:00
author
Alex Chan <alex@alexwlchan.net>
parent
3b78b68
message
handle the case where sns fails
changed files
1 file, 32 additions, 6 deletions

Changed files

aws/bulk_sns_publish (3811) → aws/bulk_sns_publish (4523)

diff --git a/aws/bulk_sns_publish b/aws/bulk_sns_publish
index 154f750..5017617 100755
--- a/aws/bulk_sns_publish
+++ b/aws/bulk_sns_publish
@@ -22,9 +22,11 @@ This script provides a convenient wrapper for doing so.
 """
 
 import argparse
+import functools
+import itertools
 import os
-import secrets
 import sys
+import uuid
 
 import boto3
 import more_itertools
@@ -77,14 +79,28 @@ def get_session(*, topic_arn):
         return boto3.Session()
 
 
+def chunked_iterable(iterable, size):
+    """
+    Break an iterable into pieces of the given size.
+
+    See https://alexwlchan.net/2018/iterating-in-fixed-size-chunks/
+    """
+    it = iter(iterable)
+    while True:
+        chunk = tuple(itertools.islice(it, size))
+        if not chunk:
+            break
+        yield chunk
+
+
 def get_batch_entries(path):
     """
     Given a file which contains one notification per line, generate a series
     of values that can be passed as the `PublishBatchRequestEntries` argument
     to the `Sns.publish_batch` method.
     """
-    for batch in more_itertools.chunked(open(path), n=10):
-        yield [{"Id": secrets.token_hex(), "Message": line.strip()} for line in batch]
+    for batch in chunked_iterable(open(path), size=10):
+        yield [{"Id": str(uuid.uuid4()), "Message": line.strip()} for line in batch]
 
 
 def parse_args():
@@ -103,6 +119,17 @@ def parse_args():
     return parser.parse_args()
 
 
+def publish_batch(sns_client, topic_arn, batch_entries):
+    resp = sns_client.publish_batch(
+                    TopicArn=topic_arn, PublishBatchRequestEntries=batch_entries
+                )
+
+    # This is to account for any failures in sending messages to SNS.
+    # I've never actually had this happen in practice so what should happen
+    # here is a little bit
+    assert len(resp['Failed']) == 0, resp
+
+
 def publish_messages(*, input_file, topic_arn):
     sess = get_session(topic_arn=topic_arn)
 
@@ -117,10 +144,9 @@ def publish_messages(*, input_file, topic_arn):
 
     with tqdm.tqdm(total=total_entries) as pbar:
         for (batch, _) in concurrently(
-            handler=lambda batch_entries: sns_client.publish_batch(
-                TopicArn=topic_arn, PublishBatchRequestEntries=batch_entries
-            ),
+            handler=functools.partial(publish_batch, sns_client, topic_arn),
             inputs=get_batch_entries(input_file),
+            max_concurrency=1
         ):
             pbar.update(len(batch))