handle the case where sns fails
- ID
7cfc8de- date
2023-04-16 15:00:31+00:00- author
Alex Chan <alex@alexwlchan.net>- parent
3b78b68- message
handle the case where sns fails- changed files
1 file, 32 additions, 6 deletions
Changed files
aws/bulk_sns_publish (3811) → aws/bulk_sns_publish (4523)
diff --git a/aws/bulk_sns_publish b/aws/bulk_sns_publish
index 154f750..5017617 100755
--- a/aws/bulk_sns_publish
+++ b/aws/bulk_sns_publish
@@ -22,9 +22,11 @@ This script provides a convenient wrapper for doing so.
"""
import argparse
+import functools
+import itertools
import os
-import secrets
import sys
+import uuid
import boto3
import more_itertools
@@ -77,14 +79,28 @@ def get_session(*, topic_arn):
return boto3.Session()
+def chunked_iterable(iterable, size):
+ """
+ Break an iterable into pieces of the given size.
+
+ See https://alexwlchan.net/2018/iterating-in-fixed-size-chunks/
+ """
+ it = iter(iterable)
+ while True:
+ chunk = tuple(itertools.islice(it, size))
+ if not chunk:
+ break
+ yield chunk
+
+
def get_batch_entries(path):
"""
Given a file which contains one notification per line, generate a series
of values that can be passed as the `PublishBatchRequestEntries` argument
to the `Sns.publish_batch` method.
"""
- for batch in more_itertools.chunked(open(path), n=10):
- yield [{"Id": secrets.token_hex(), "Message": line.strip()} for line in batch]
+ for batch in chunked_iterable(open(path), size=10):
+ yield [{"Id": str(uuid.uuid4()), "Message": line.strip()} for line in batch]
def parse_args():
@@ -103,6 +119,17 @@ def parse_args():
return parser.parse_args()
+def publish_batch(sns_client, topic_arn, batch_entries):
+ resp = sns_client.publish_batch(
+ TopicArn=topic_arn, PublishBatchRequestEntries=batch_entries
+ )
+
+ # This is to account for any failures in sending messages to SNS.
+ # I've never actually had this happen in practice so what should happen
+ # here is a little bit
+ assert len(resp['Failed']) == 0, resp
+
+
def publish_messages(*, input_file, topic_arn):
sess = get_session(topic_arn=topic_arn)
@@ -117,10 +144,9 @@ def publish_messages(*, input_file, topic_arn):
with tqdm.tqdm(total=total_entries) as pbar:
for (batch, _) in concurrently(
- handler=lambda batch_entries: sns_client.publish_batch(
- TopicArn=topic_arn, PublishBatchRequestEntries=batch_entries
- ),
+ handler=functools.partial(publish_batch, sns_client, topic_arn),
inputs=get_batch_entries(input_file),
+ max_concurrency=1
):
pbar.update(len(batch))