今回のサンプルコードは Github で公開しています。
Github/pagination-drift-specimen
概要
データベースに記録してある大量のデータを一括で処理する際に発生しがちなパターンとして以下のものがあります。
- 対象データが大量(10万件とか)のため、アプリケーション側でループを書き、1回のループ内では対象データを
LIMIT 1000 OFFSET ...
のSQLで取得する - アプリケーション側での処理で、処理対象のデータの抽出SQL の
WHERE
句の条件を、条件に当てはまらなくなるようにUPDATE
する
このような処理を書いてしまうと、ループで全対象データを全て処理できずに漏れが発生します。
具体的な例
次のようなテーブル product があるとします。
- id: int プライマリーキー
- name: char 商品名
- active: bool 有効フラグ
- expiry: datetime 有効期限
100万件の商品データが入っているものとしてください。
expiry
が過去になった商品すべての active
を False
にするという要件があり、 SQL の UPDATE
一発ではできずアプリケーションコードで行う必要がある(ロギングとか通知とか他のデータを連動して変更する必要があるとか)場合、
SELECT *
FROM product
WHERE active = TRUE
AND expiry < NOW()
LIMIT 1000 OFFSET ?
ORDER BY id
で対象データを取得し、アプリケーションコード側で
UPDATE product
SET active = FALSE
WHERE id = ?
というSQLを順次発行していくことになります。
2ループ目は、検索対象SQL を OFFSET 1000
としたいところですが、 そうするとすでに1ループ目で処理した対象は active = FALSE
となっており、SELECT
の WHERE
の対象からは外れてしまっているので、1000件ごとに処理されないレコードが発生してしまいます。
Djangoの例
Django で書くと、次のようになります。
from django.core.paginator import Paginator
from ..models import Product
# 有効期限が過ぎている商品を全て検索する
product_qs = Product.objects.filter(
active=True,
expiry__lt=timezone.now()
).order_by('id')
# Paginator を使って1000件ごとに SQL を発行する
paginator = Paginator(product_qs, 1000)
for page_number in paginator.page_range:
for product in paginator.page(page_number).object_list:
# 該当した Product を無効化する
product.active = False
# ロギングとか通知とかの処理...
product.save(update_fields=['active'])
上記コードの問題に、すぐ気づくのは難しいのではないでしょうか。
私は自信がありません。
この現象の名前
この現象に固有の名前がついているかはわかりませんでしたが、ChatGPT に質問したところ「Pagination Drift」という名前だという返答がきました。
本当か? Pagination Drift でググってもそれらしい記事は見つけられません。
ともあれ納得できる名前ではあります。
今回は、この現象を「パジネーションドリフト」と呼びます。
より一般的な名前を知ってたら教えてください。
サンプルコード
今回の問題のサンプルコードを Django で作っています。
この記事の前半では、「商品の有効期限が過去になったものの active
を FALSE
にする」という条件ですが、サンプルコードでは 「商品名に「不適切」という文字列が含まれている商品の active
を FALSE
にする」という仕様に変えています。日時が関係すると検証が少し困難になるため、よりシンプルな条件にしたかったからです。
今回のサンプルコードは Github で公開しています。
Github/pagination-drift-specimen
サンプルコードの解説
./manage.py s01_refresh_products
のコード で、10000件のダミーの商品レコードを作ります。 そのうちの半分は、商品名に「不適切」という文言が含まれます。
s01_refresh_products
ダミー商品を 10000件作成
def make_product_name(i):
"""
テキトーな商品名を作成。iが偶数の場合、「この商品は不適切です。」という文言が入る。
"""
if i % 2 == 0:
return f'商品 {i} この商品は不適切です。'
else:
return f'商品 {i}'
Product.objects.all().delete()
Product.objects.bulk_create(
[
Product(name=make_product_name(i))
for i in range(1, 10001)
]
)
./manage.py s02_active_inactive_counts
を実行すると、Product を active 別に集計して内容をコンソールに表示します。
s02_active_inactive_counts
qs = Product.objects.values('active').annotate(
count=Count('active')).order_by('active')
for r in qs:
print(f"active={r['active']}, Count: {r['count']}")
./manage.py s11_deactivate_inappropriate_products
は、name
に「不適切」が含まれる Product の active
をすべて False
に更新する意図があるスクリプトですが、パジネーションドリフトする不具合があるため、全てのレコードを対象に処理ができず、処理漏れが発生します。
s11_deactivate_inappropriate_products
product_qs = Product.objects.filter(
name__contains='不適切', active=True
).order_by('id')
# Paginator を使って1000件ごとに SQL を発行する
paginator = Paginator(product_qs, 1000)
for page_number in paginator.page_range:
for product in paginator.page(page_number).object_list:
# 該当した Product を無効化する
product.active = False
product.save(update_fields=['active'])
self.stdout.write(
self.style.SUCCESS(
f'{product.id} {product.name} を無効化しました。'))
↑このコードには不具合がある。
対策
A. 全対象レコードのプライマリーキーをあらかじめ取得しておく
s21_preload_all_primary_keys
# 対象のレコードのプライマリーキーを全部取得する
product_pks = list(Product.objects.filter(
name__contains='不適切', active=True
).order_by('id').values_list('id', flat=True))
# Paginator を使って1000件ごとに切り分けながら処理する
paginator = Paginator(product_pks, 1000)
for page_number in paginator.page_range:
products = Product.objects.filter(
id__in=paginator.page(page_number).object_list)
for product in products:
# 該当した Product を無効化する
product.active = False
product.save(update_fields=['active'])
self.stdout.write(
self.style.SUCCESS(
f'{product.id} {product.name} を無効化しました。'))
% ./manage.py s01_refresh_products
% ./manage.py s02_active_inactive_counts
active=True, Count: 10000
% ./manage.py s21_preload_all_primary_keys
% ./manage.py s02_active_inactive_counts
active=False, Count: 5000
active=True, Count: 5000
対象のレコード5000件を漏れなく処理できました。
対象のPKをすべてメモリに蓄えるため、メモリを圧迫する処理ですが、試しに int を10万個入れたリストのメモリ容量を確認すると 800KB程度です。
>>> sys.getsizeof(list(range(100000)))
800056
そのため、実行環境にもよりますが、50万件まではこの方法で全く問題無いと思います。
数十万件で、PILで1枚画像処理をするのと同程度のメモリ消費量です。
B. 対象データのPK範囲で処理を分ける(1回のループ内の処理件数は変動する)
対象データ ( Product ) のプライマリーキーが連番になっている場合、1回の処理で #1〜#2000 を取得して処理し、次のループで #2001〜#4000 を取得して処理する、という方法もあります。
処理がシンプルでコードが書きやすく不具合を起こしにくいですが、1ループ内の対象数がまちまちになるので効率化という点で少し低く、雑な印象があります。
s22_segmented_by_product_pk
aggregated = Product.objects.aggregate(
min_id=Min('id'), max_id=Max('id'))
product_qs = Product.objects.filter(
name__contains='不適切', active=True
).order_by('id')
page_size = 2000
for offset in range(
aggregated['min_id'], aggregated['max_id'] + 1, page_size
):
products = product_qs.filter(
id__gte=offset, id__lt=offset + page_size)
for product in products:
# 該当した Product を無効化する
product.active = False
product.save(update_fields=['active'])
self.stdout.write(
self.style.SUCCESS(
f'{product.id} {product.name} を無効化しました。'))
C. 前回ループの最後のPK 以降を次のループで処理する
まず対象データ ( Product ) をプライマリーキーでソートして対象データの1000件を取得し、次回ループは前のループで処理したプライマリーキーの最後のもの以降を検索条件に加えて行う方法です。
処理が少し読みにくいですが、毎回きっちり1000件ごとに処理されるので効率は良いと思います。
s23_filter_last_pk
product_qs = Product.objects.filter(
name__contains='不適切', active=True
).order_by('id')
page_size = 1000
last_max_id = 0
while True:
products = list(product_qs.filter(
id__gt=last_max_id)[:page_size])
if not products:
break
last_max_id = products[-1].id
for product in products:
# 該当した Product を無効化する
product.active = False
product.save(update_fields=['active'])
self.stdout.write(
self.style.SUCCESS(
f'{product.id} {product.name} を無効化しました。'))
D. 外部に記録する
規模が大きい場合や、他の要件の解決のために、処理対象を SQL の一時テーブルに記録してから処理をしたり、Redis に入れてから処理する場合もあります。
SQL のテンポラリーテーブルを使う
SQLのテンポラリーテーブルを使う場合、Django の ORM ではサポートされていないため、生 SQL で `CREATE TEMPORARY TABLE ...`` を発行する必要があります。今回は詳しくは書きません。
Redis のリストを使う
Redis を使う場合は、リスト構造で保存できるため、およそ以下のようなコードになります。今回は詳しくは書きません。
from django_redis import get_redis_connection
redis_connection = get_redis_connection('default')
key = ...
redis_connection.delete(key)
redis_connection.rpush(key, *対象データ)
print(redis_connection.llen(key)) # 件数を表示
for .....:
対象データ = redis_connection.lrange(key, start, end)
...