aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSam Gross <colesbury@gmail.com>2024-06-03 18:47:34 -0400
committerGitHub <noreply@github.com>2024-06-03 22:47:34 +0000
commit79fae3b0a15be30d35131420f030c9a31338b357 (patch)
treef4aea481cce7e56a0886f429780baa18e049e05b
parent[3.13] gh-117657: Fix race involving immortalizing objects (GH-119927) (#120005) (diff)
downloadcpython-79fae3b0a15be30d35131420f030c9a31338b357.tar.gz
cpython-79fae3b0a15be30d35131420f030c9a31338b357.tar.bz2
cpython-79fae3b0a15be30d35131420f030c9a31338b357.zip
[3.13] gh-117657: Fix itertools.count thread safety (GH-119268) (#120007)
Fix itertools.count in free-threading mode (cherry picked from commit 87939bd5790accea77c5a81093f16f28d3f0b429) Co-authored-by: Arnon Yaari <wiggin15@yahoo.com>
-rw-r--r--Lib/test/test_itertools.py24
-rw-r--r--Modules/itertoolsmodule.c40
-rw-r--r--Tools/tsan/suppressions_free_threading.txt1
3 files changed, 54 insertions, 11 deletions
diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py
index e243da309f..2c92d880c1 100644
--- a/Lib/test/test_itertools.py
+++ b/Lib/test/test_itertools.py
@@ -644,7 +644,7 @@ class TestBasicOps(unittest.TestCase):
count(1, maxsize+5); sys.exc_info()
@pickle_deprecated
- def test_count_with_stride(self):
+ def test_count_with_step(self):
self.assertEqual(lzip('abc',count(2,3)), [('a', 2), ('b', 5), ('c', 8)])
self.assertEqual(lzip('abc',count(start=2,step=3)),
[('a', 2), ('b', 5), ('c', 8)])
@@ -699,6 +699,28 @@ class TestBasicOps(unittest.TestCase):
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
self.pickletest(proto, count(i, j))
+ @threading_helper.requires_working_threading()
+ def test_count_threading(self, step=1):
+ # this test verifies multithreading consistency, which is
+ # mostly for testing builds without GIL, but nice to test anyway
+ count_to = 10_000
+ num_threads = 10
+ c = count(step=step)
+ def counting_thread():
+ for i in range(count_to):
+ next(c)
+ threads = []
+ for i in range(num_threads):
+ thread = threading.Thread(target=counting_thread)
+ thread.start()
+ threads.append(thread)
+ for thread in threads:
+ thread.join()
+ self.assertEqual(next(c), count_to * num_threads * step)
+
+ def test_count_with_step_threading(self):
+ self.test_count_threading(step=5)
+
def test_cycle(self):
self.assertEqual(take(10, cycle('abc')), list('abcabcabca'))
self.assertEqual(list(cycle('')), [])
diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c
index 8641c2f87e..0d6ff20489 100644
--- a/Modules/itertoolsmodule.c
+++ b/Modules/itertoolsmodule.c
@@ -1,13 +1,14 @@
#include "Python.h"
-#include "pycore_call.h" // _PyObject_CallNoArgs()
-#include "pycore_ceval.h" // _PyEval_GetBuiltin()
-#include "pycore_long.h" // _PyLong_GetZero()
-#include "pycore_moduleobject.h" // _PyModule_GetState()
-#include "pycore_typeobject.h" // _PyType_GetModuleState()
-#include "pycore_object.h" // _PyObject_GC_TRACK()
-#include "pycore_tuple.h" // _PyTuple_ITEMS()
+#include "pycore_call.h" // _PyObject_CallNoArgs()
+#include "pycore_ceval.h" // _PyEval_GetBuiltin()
+#include "pycore_critical_section.h" // Py_BEGIN_CRITICAL_SECTION()
+#include "pycore_long.h" // _PyLong_GetZero()
+#include "pycore_moduleobject.h" // _PyModule_GetState()
+#include "pycore_typeobject.h" // _PyType_GetModuleState()
+#include "pycore_object.h" // _PyObject_GC_TRACK()
+#include "pycore_tuple.h" // _PyTuple_ITEMS()
-#include <stddef.h> // offsetof()
+#include <stddef.h> // offsetof()
/* Itertools module written and maintained
by Raymond D. Hettinger <python@rcn.com>
@@ -4037,7 +4038,7 @@ fast_mode: when cnt an integer < PY_SSIZE_T_MAX and no step is specified.
assert(cnt != PY_SSIZE_T_MAX && long_cnt == NULL && long_step==PyLong(1));
Advances with: cnt += 1
- When count hits Y_SSIZE_T_MAX, switch to slow_mode.
+ When count hits PY_SSIZE_T_MAX, switch to slow_mode.
slow_mode: when cnt == PY_SSIZE_T_MAX, step is not int(1), or cnt is a float.
@@ -4186,9 +4187,30 @@ count_nextlong(countobject *lz)
static PyObject *
count_next(countobject *lz)
{
+#ifndef Py_GIL_DISABLED
if (lz->cnt == PY_SSIZE_T_MAX)
return count_nextlong(lz);
return PyLong_FromSsize_t(lz->cnt++);
+#else
+ // free-threading version
+ // fast mode uses compare-exchange loop
+ // slow mode uses a critical section
+ PyObject *returned;
+ Py_ssize_t cnt;
+
+ cnt = _Py_atomic_load_ssize_relaxed(&lz->cnt);
+ for (;;) {
+ if (cnt == PY_SSIZE_T_MAX) {
+ Py_BEGIN_CRITICAL_SECTION(lz);
+ returned = count_nextlong(lz);
+ Py_END_CRITICAL_SECTION();
+ return returned;
+ }
+ if (_Py_atomic_compare_exchange_ssize(&lz->cnt, &cnt, cnt + 1)) {
+ return PyLong_FromSsize_t(cnt);
+ }
+ }
+#endif
}
static PyObject *
diff --git a/Tools/tsan/suppressions_free_threading.txt b/Tools/tsan/suppressions_free_threading.txt
index 0bb0183147..d5fcac61f0 100644
--- a/Tools/tsan/suppressions_free_threading.txt
+++ b/Tools/tsan/suppressions_free_threading.txt
@@ -47,7 +47,6 @@ race_top:_PyImport_AcquireLock
race_top:_Py_dict_lookup_threadsafe
race_top:_imp_release_lock
race_top:_multiprocessing_SemLock_acquire_impl
-race_top:count_next
race_top:dictiter_new
race_top:dictresize
race_top:insert_to_emptydict