From 7bac207afd17455b822856cd5e583b7d4188c018 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Wed, 4 Mar 2026 18:58:52 +0200 Subject: [PATCH 01/18] fix: replace Python 2 iter().next() with next(iter()) In Python 3, iterator objects do not have a .next() method; the built-in next() function must be used instead. The call at line 355 of test_class_construction.py used the Python 2 pattern iter(...).next(), which would raise AttributeError if ever reached at runtime. Currently the test passes only because CQLEngineException is raised before .next() is called, but this is fragile: if the exception timing changes, the test would fail with AttributeError instead of the expected CQLEngineException. Replace with next(iter(...)) for correct Python 3 usage. --- .../model/test_class_construction.py | 207 ++++++++++-------- 1 file changed, 115 insertions(+), 92 deletions(-) diff --git a/tests/integration/cqlengine/model/test_class_construction.py b/tests/integration/cqlengine/model/test_class_construction.py index df0a57d543..fae0d01105 100644 --- a/tests/integration/cqlengine/model/test_class_construction.py +++ b/tests/integration/cqlengine/model/test_class_construction.py @@ -16,7 +16,12 @@ import warnings from cassandra.cqlengine import columns, CQLEngineException -from cassandra.cqlengine.models import Model, ModelException, ModelDefinitionException, ColumnQueryEvaluator +from cassandra.cqlengine.models import ( + Model, + ModelException, + ModelDefinitionException, + ColumnQueryEvaluator, +) from cassandra.cqlengine.query import ModelQuerySet, DMLQuery from tests.integration.cqlengine.base import BaseCassEngTestCase @@ -35,19 +40,18 @@ def test_column_attributes_handled_correctly(self): """ class TestModel(Model): - - id = columns.UUID(primary_key=True, default=lambda:uuid4()) + id = columns.UUID(primary_key=True, default=lambda: uuid4()) text = columns.Text() # check class attibutes - assert hasattr(TestModel, '_columns') - assert hasattr(TestModel, 'id') - assert hasattr(TestModel, 'text') + assert hasattr(TestModel, "_columns") + assert hasattr(TestModel, "id") + assert hasattr(TestModel, "text") # check instance attributes inst = TestModel() - assert hasattr(inst, 'id') - assert hasattr(inst, 'text') + assert hasattr(inst, "id") + assert hasattr(inst, "text") assert inst.id is not None assert inst.text is None @@ -57,35 +61,35 @@ def test_values_on_instantiation(self): """ class TestPerson(Model): - first_name = columns.Text(primary_key=True, default='kevin') - last_name = columns.Text(default='deldycke') + first_name = columns.Text(primary_key=True, default="kevin") + last_name = columns.Text(default="deldycke") # Check that defaults are available at instantiation. inst1 = TestPerson() - assert hasattr(inst1, 'first_name') - assert hasattr(inst1, 'last_name') - assert inst1.first_name == 'kevin' - assert inst1.last_name == 'deldycke' + assert hasattr(inst1, "first_name") + assert hasattr(inst1, "last_name") + assert inst1.first_name == "kevin" + assert inst1.last_name == "deldycke" # Check that values on instantiation overrides defaults. - inst2 = TestPerson(first_name='bob', last_name='joe') - assert inst2.first_name == 'bob' - assert inst2.last_name == 'joe' + inst2 = TestPerson(first_name="bob", last_name="joe") + assert inst2.first_name == "bob" + assert inst2.last_name == "joe" def test_db_map(self): """ Tests that the db_map is properly defined -the db_map allows columns """ - class WildDBNames(Model): - id = columns.UUID(primary_key=True, default=lambda:uuid4()) - content = columns.Text(db_field='words_and_whatnot') - numbers = columns.Integer(db_field='integers_etc') + class WildDBNames(Model): + id = columns.UUID(primary_key=True, default=lambda: uuid4()) + content = columns.Text(db_field="words_and_whatnot") + numbers = columns.Integer(db_field="integers_etc") db_map = WildDBNames._db_map - assert db_map['words_and_whatnot'] == 'content' - assert db_map['integers_etc'] == 'numbers' + assert db_map["words_and_whatnot"] == "content" + assert db_map["integers_etc"] == "numbers" def test_attempting_to_make_duplicate_column_names_fails(self): """ @@ -93,9 +97,10 @@ def test_attempting_to_make_duplicate_column_names_fails(self): """ with pytest.raises(ModelException, match=r".*more than once$"): + class BadNames(Model): words = columns.Text(primary_key=True) - content = columns.Text(db_field='words') + content = columns.Text(db_field="words") def test_column_ordering_is_preserved(self): """ @@ -103,18 +108,22 @@ def test_column_ordering_is_preserved(self): """ class Stuff(Model): - - id = columns.UUID(primary_key=True, default=lambda:uuid4()) + id = columns.UUID(primary_key=True, default=lambda: uuid4()) words = columns.Text() content = columns.Text() numbers = columns.Integer() - assert [x for x in Stuff._columns.keys()] == ['id', 'words', 'content', 'numbers'] + assert [x for x in Stuff._columns.keys()] == [ + "id", + "words", + "content", + "numbers", + ] def test_exception_raised_when_creating_class_without_pk(self): with pytest.raises(ModelDefinitionException): - class TestModel(Model): + class TestModel(Model): count = columns.Integer() text = columns.Text(required=False) @@ -122,9 +131,9 @@ def test_value_managers_are_keeping_model_instances_isolated(self): """ Tests that instance value managers are isolated from other instances """ - class Stuff(Model): - id = columns.UUID(primary_key=True, default=lambda:uuid4()) + class Stuff(Model): + id = columns.UUID(primary_key=True, default=lambda: uuid4()) num = columns.Integer() inst1 = Stuff(num=5) @@ -138,55 +147,55 @@ def test_superclass_fields_are_inherited(self): """ Tests that fields defined on the super class are inherited properly """ - class TestModel(Model): - id = columns.UUID(primary_key=True, default=lambda:uuid4()) + class TestModel(Model): + id = columns.UUID(primary_key=True, default=lambda: uuid4()) text = columns.Text() class InheritedModel(TestModel): numbers = columns.Integer() - assert 'text' in InheritedModel._columns - assert 'numbers' in InheritedModel._columns + assert "text" in InheritedModel._columns + assert "numbers" in InheritedModel._columns def test_column_family_name_generation(self): - """ Tests that auto column family name generation works as expected """ - class TestModel(Model): + """Tests that auto column family name generation works as expected""" - id = columns.UUID(primary_key=True, default=lambda:uuid4()) + class TestModel(Model): + id = columns.UUID(primary_key=True, default=lambda: uuid4()) text = columns.Text() - assert TestModel.column_family_name(include_keyspace=False) == 'test_model' + assert TestModel.column_family_name(include_keyspace=False) == "test_model" def test_partition_keys(self): """ Test compound partition key definition """ - class ModelWithPartitionKeys(Model): - id = columns.UUID(primary_key=True, default=lambda:uuid4()) + class ModelWithPartitionKeys(Model): + id = columns.UUID(primary_key=True, default=lambda: uuid4()) c1 = columns.Text(primary_key=True) p1 = columns.Text(partition_key=True) p2 = columns.Text(partition_key=True) cols = ModelWithPartitionKeys._columns - assert cols['c1'].primary_key - assert not cols['c1'].partition_key + assert cols["c1"].primary_key + assert not cols["c1"].partition_key - assert cols['p1'].primary_key - assert cols['p1'].partition_key - assert cols['p2'].primary_key - assert cols['p2'].partition_key + assert cols["p1"].primary_key + assert cols["p1"].partition_key + assert cols["p2"].primary_key + assert cols["p2"].partition_key - obj = ModelWithPartitionKeys(p1='a', p2='b') - assert obj.pk == ('a', 'b') + obj = ModelWithPartitionKeys(p1="a", p2="b") + assert obj.pk == ("a", "b") def test_del_attribute_is_assigned_properly(self): - """ Tests that columns that can be deleted have the del attribute """ - class DelModel(Model): + """Tests that columns that can be deleted have the del attribute""" - id = columns.UUID(primary_key=True, default=lambda:uuid4()) + class DelModel(Model): + id = columns.UUID(primary_key=True, default=lambda: uuid4()) key = columns.Integer(primary_key=True) data = columns.Integer(required=False) @@ -196,15 +205,13 @@ class DelModel(Model): del model.key def test_does_not_exist_exceptions_are_not_shared_between_model(self): - """ Tests that DoesNotExist exceptions are not the same exception between models """ + """Tests that DoesNotExist exceptions are not the same exception between models""" class Model1(Model): - - id = columns.UUID(primary_key=True, default=lambda:uuid4()) + id = columns.UUID(primary_key=True, default=lambda: uuid4()) class Model2(Model): - - id = columns.UUID(primary_key=True, default=lambda:uuid4()) + id = columns.UUID(primary_key=True, default=lambda: uuid4()) try: raise Model1.DoesNotExist @@ -215,10 +222,10 @@ class Model2(Model): pass def test_does_not_exist_inherits_from_superclass(self): - """ Tests that a DoesNotExist exception can be caught by it's parent class DoesNotExist """ - class Model1(Model): + """Tests that a DoesNotExist exception can be caught by it's parent class DoesNotExist""" - id = columns.UUID(primary_key=True, default=lambda:uuid4()) + class Model1(Model): + id = columns.UUID(primary_key=True, default=lambda: uuid4()) class Model2(Model1): pass @@ -233,6 +240,7 @@ class Model2(Model1): def test_abstract_model_keyspace_warning_is_skipped(self): with warnings.catch_warnings(record=True) as warn: + class NoKeyspace(Model): __abstract__ = True key = columns.UUID(primary_key=True) @@ -241,30 +249,33 @@ class NoKeyspace(Model): class TestManualTableNaming(BaseCassEngTestCase): - class RenamedTest(Model): - __keyspace__ = 'whatever' - __table_name__ = 'manual_name' + __keyspace__ = "whatever" + __table_name__ = "manual_name" id = columns.UUID(primary_key=True) data = columns.Text() def test_proper_table_naming(self): - assert self.RenamedTest.column_family_name(include_keyspace=False) == 'manual_name' - assert self.RenamedTest.column_family_name(include_keyspace=True) == 'whatever.manual_name' + assert ( + self.RenamedTest.column_family_name(include_keyspace=False) == "manual_name" + ) + assert ( + self.RenamedTest.column_family_name(include_keyspace=True) + == "whatever.manual_name" + ) class TestManualTableNamingCaseSensitive(BaseCassEngTestCase): - class RenamedCaseInsensitiveTest(Model): - __keyspace__ = 'whatever' - __table_name__ = 'Manual_Name' + __keyspace__ = "whatever" + __table_name__ = "Manual_Name" id = columns.UUID(primary_key=True) class RenamedCaseSensitiveTest(Model): - __keyspace__ = 'whatever' - __table_name__ = 'Manual_Name' + __keyspace__ = "whatever" + __table_name__ = "Manual_Name" __table_name_case_sensitive__ = True id = columns.UUID(primary_key=True) @@ -279,8 +290,14 @@ def test_proper_table_naming_case_insensitive(self): @test_category object_mapper """ - assert self.RenamedCaseInsensitiveTest.column_family_name(include_keyspace=False) == 'manual_name' - assert self.RenamedCaseInsensitiveTest.column_family_name(include_keyspace=True) == 'whatever.manual_name' + assert ( + self.RenamedCaseInsensitiveTest.column_family_name(include_keyspace=False) + == "manual_name" + ) + assert ( + self.RenamedCaseInsensitiveTest.column_family_name(include_keyspace=True) + == "whatever.manual_name" + ) def test_proper_table_naming_case_sensitive(self): """ @@ -293,8 +310,14 @@ def test_proper_table_naming_case_sensitive(self): @test_category object_mapper """ - assert self.RenamedCaseSensitiveTest.column_family_name(include_keyspace=False) == '"Manual_Name"' - assert self.RenamedCaseSensitiveTest.column_family_name(include_keyspace=True) == 'whatever."Manual_Name"' + assert ( + self.RenamedCaseSensitiveTest.column_family_name(include_keyspace=False) + == '"Manual_Name"' + ) + assert ( + self.RenamedCaseSensitiveTest.column_family_name(include_keyspace=True) + == 'whatever."Manual_Name"' + ) class AbstractModel(Model): @@ -307,7 +330,6 @@ class ConcreteModel(AbstractModel): class AbstractModelWithCol(Model): - __abstract__ = True pkey = columns.Integer(primary_key=True) @@ -324,45 +346,46 @@ class AbstractModelWithFullCols(Model): class TestAbstractModelClasses(BaseCassEngTestCase): - def test_id_field_is_not_created(self): - """ Tests that an id field is not automatically generated on abstract classes """ - assert not hasattr(AbstractModel, 'id') - assert not hasattr(AbstractModelWithCol, 'id') + """Tests that an id field is not automatically generated on abstract classes""" + assert not hasattr(AbstractModel, "id") + assert not hasattr(AbstractModelWithCol, "id") def test_id_field_is_not_created_on_subclass(self): - assert not hasattr(ConcreteModel, 'id') + assert not hasattr(ConcreteModel, "id") def test_abstract_attribute_is_not_inherited(self): - """ Tests that __abstract__ attribute is not inherited """ + """Tests that __abstract__ attribute is not inherited""" assert not ConcreteModel.__abstract__ assert not ConcreteModelWithCol.__abstract__ def test_attempting_to_save_abstract_model_fails(self): - """ Attempting to save a model from an abstract model should fail """ + """Attempting to save a model from an abstract model should fail""" with pytest.raises(CQLEngineException): AbstractModelWithFullCols.create(pkey=1, data=2) def test_attempting_to_create_abstract_table_fails(self): - """ Attempting to create a table from an abstract model should fail """ + """Attempting to create a table from an abstract model should fail""" from cassandra.cqlengine.management import sync_table + with pytest.raises(CQLEngineException): sync_table(AbstractModelWithFullCols) def test_attempting_query_on_abstract_model_fails(self): - """ Tests attempting to execute query with an abstract model fails """ + """Tests attempting to execute query with an abstract model fails""" with pytest.raises(CQLEngineException): - iter(AbstractModelWithFullCols.objects(pkey=5)).next() + next(iter(AbstractModelWithFullCols.objects(pkey=5))) def test_abstract_columns_are_inherited(self): - """ Tests that columns defined in the abstract class are inherited into the concrete class """ - assert hasattr(ConcreteModelWithCol, 'pkey') + """Tests that columns defined in the abstract class are inherited into the concrete class""" + assert hasattr(ConcreteModelWithCol, "pkey") assert isinstance(ConcreteModelWithCol.pkey, ColumnQueryEvaluator) - assert isinstance(ConcreteModelWithCol._columns['pkey'], columns.Column) + assert isinstance(ConcreteModelWithCol._columns["pkey"], columns.Column) def test_concrete_class_table_creation_cycle(self): - """ Tests that models with inherited abstract classes can be created, and have io performed """ + """Tests that models with inherited abstract classes can be created, and have io performed""" from cassandra.cqlengine.management import sync_table, drop_table + sync_table(ConcreteModelWithCol) w1 = ConcreteModelWithCol.create(pkey=5, data=6) @@ -380,9 +403,10 @@ def test_concrete_class_table_creation_cycle(self): class TestCustomQuerySet(BaseCassEngTestCase): - """ Tests overriding the default queryset class """ + """Tests overriding the default queryset class""" - class TestException(Exception): pass + class TestException(Exception): + pass def test_overriding_queryset(self): @@ -397,7 +421,7 @@ class CQModel(Model): data = columns.Text() with pytest.raises(self.TestException): - CQModel.create(part=uuid4(), data='s') + CQModel.create(part=uuid4(), data="s") def test_overriding_dmlqueryset(self): @@ -406,7 +430,6 @@ def save(iself): raise self.TestException class CDQModel(Model): - __dmlquery__ = DMLQ part = columns.UUID(primary_key=True) data = columns.Text() From 5169599ca8fe8340ff4e202058a5c7140341d193 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Wed, 4 Mar 2026 18:59:44 +0200 Subject: [PATCH 02/18] fix: replace dead UnicodeDecodeError handlers with isinstance checks In Python 3, calling str.encode() can only raise UnicodeEncodeError, never UnicodeDecodeError. The except UnicodeDecodeError branches in AsciiType.serialize and UTF8Type.serialize were leftover from Python 2, where str.encode() could trigger an implicit decode of a byte string. These dead except branches silently masked the intended behavior. In Python 3, if the input is already bytes there is no .encode() to call, so the original code would raise AttributeError rather than returning the value as-is. Replace the try/except pattern with explicit isinstance(var, bytes) checks, which correctly handles both str and bytes inputs on Python 3. --- cassandra/cqltypes.py | 620 ++++++++++++++++++++++++++---------------- 1 file changed, 386 insertions(+), 234 deletions(-) diff --git a/cassandra/cqltypes.py b/cassandra/cqltypes.py index d33e5fceb8..e043a05015 100644 --- a/cassandra/cqltypes.py +++ b/cassandra/cqltypes.py @@ -43,21 +43,41 @@ import sys from uuid import UUID -from cassandra.marshal import (int8_pack, int8_unpack, int16_pack, int16_unpack, - uint16_pack, uint16_unpack, uint32_pack, uint32_unpack, - int32_pack, int32_unpack, int64_pack, int64_unpack, - float_pack, float_unpack, double_pack, double_unpack, - varint_pack, varint_unpack, point_be, point_le, - vints_pack, vints_unpack, uvint_unpack, uvint_pack) +from cassandra.marshal import ( + int8_pack, + int8_unpack, + int16_pack, + int16_unpack, + uint16_pack, + uint16_unpack, + uint32_pack, + uint32_unpack, + int32_pack, + int32_unpack, + int64_pack, + int64_unpack, + float_pack, + float_unpack, + double_pack, + double_unpack, + varint_pack, + varint_unpack, + point_be, + point_le, + vints_pack, + vints_unpack, + uvint_unpack, + uvint_pack, +) from cassandra import util _little_endian_flag = 1 # we always serialize LE import ipaddress -apache_cassandra_type_prefix = 'org.apache.cassandra.db.marshal.' +apache_cassandra_type_prefix = "org.apache.cassandra.db.marshal." -cassandra_empty_type = 'org.apache.cassandra.db.marshal.EmptyType' -cql_empty_type = 'empty' +cassandra_empty_type = "org.apache.cassandra.db.marshal.EmptyType" +cql_empty_type = "empty" log = logging.getLogger(__name__) @@ -66,12 +86,12 @@ def _name_from_hex_string(encoded_name): bin_str = unhexlify(encoded_name) - return bin_str.decode('ascii') + return bin_str.decode("ascii") def trim_if_startswith(s, prefix): if s.startswith(prefix): - return s[len(prefix):] + return s[len(prefix) :] return s @@ -79,11 +99,13 @@ def trim_if_startswith(s, prefix): _cqltypes = {} -cql_type_scanner = re.Scanner(( - ('frozen', None), - (r'[a-zA-Z0-9_]+', lambda s, t: t), - (r'[\s,<>]', None), -)) +cql_type_scanner = re.Scanner( + ( + ("frozen", None), + (r"[a-zA-Z0-9_]+", lambda s, t: t), + (r"[\s,<>]", None), + ) +) def cql_types_from_string(cql_type): @@ -102,20 +124,22 @@ class CassandraTypeType(type): """ def __new__(metacls, name, bases, dct): - dct.setdefault('cassname', name) + dct.setdefault("cassname", name) cls = type.__new__(metacls, name, bases, dct) - if not name.startswith('_'): + if not name.startswith("_"): _casstypes[name] = cls if not cls.typename.startswith(apache_cassandra_type_prefix): _cqltypes[cls.typename] = cls return cls -casstype_scanner = re.Scanner(( - (r'[()]', lambda s, t: t), - (r'[a-zA-Z0-9_.:=>]+', lambda s, t: t), - (r'[\s,]', None), -)) +casstype_scanner = re.Scanner( + ( + (r"[()]", lambda s, t: t), + (r"[a-zA-Z0-9_.:=>]+", lambda s, t: t), + (r"[\s,]", None), + ) +) def cqltype_to_python(cql_string): @@ -125,16 +149,18 @@ def cqltype_to_python(cql_string): int -> ['int'] frozen> -> ['frozen', ['tuple', ['text', 'int']]] """ - scanner = re.Scanner(( - (r'[a-zA-Z0-9_]+', lambda s, t: "'{}'".format(t)), - (r'<', lambda s, t: ', ['), - (r'>', lambda s, t: ']'), - (r'[, ]', lambda s, t: t), - (r'".*?"', lambda s, t: "'{}'".format(t)), - )) + scanner = re.Scanner( + ( + (r"[a-zA-Z0-9_]+", lambda s, t: "'{}'".format(t)), + (r"<", lambda s, t: ", ["), + (r">", lambda s, t: "]"), + (r"[, ]", lambda s, t: t), + (r'".*?"', lambda s, t: "'{}'".format(t)), + ) + ) scanned_tokens = scanner.scan(cql_string)[0] - hierarchy = ast.literal_eval(''.join(scanned_tokens)) + hierarchy = ast.literal_eval("".join(scanned_tokens)) return [hierarchy] if isinstance(hierarchy, str) else list(hierarchy) @@ -145,18 +171,20 @@ def python_to_cqltype(types): ['int'] -> int ['frozen', ['tuple', ['text', 'int']]] -> frozen> """ - scanner = re.Scanner(( - (r"'[a-zA-Z0-9_]+'", lambda s, t: t[1:-1]), - (r'^\[', lambda s, t: None), - (r'\]$', lambda s, t: None), - (r',\s*\[', lambda s, t: '<'), - (r'\]', lambda s, t: '>'), - (r'[, ]', lambda s, t: t), - (r'\'".*?"\'', lambda s, t: t[1:-1]), - )) + scanner = re.Scanner( + ( + (r"'[a-zA-Z0-9_]+'", lambda s, t: t[1:-1]), + (r"^\[", lambda s, t: None), + (r"\]$", lambda s, t: None), + (r",\s*\[", lambda s, t: "<"), + (r"\]", lambda s, t: ">"), + (r"[, ]", lambda s, t: t), + (r'\'".*?"\'', lambda s, t: t[1:-1]), + ) + ) scanned_tokens = scanner.scan(repr(types))[0] - cql = ''.join(scanned_tokens).replace('\\\\', '\\') + cql = "".join(scanned_tokens).replace("\\\\", "\\") return cql @@ -166,10 +194,13 @@ def _strip_frozen_from_python(types): Example: ['frozen', ['tuple', ['text', 'int']]] -> ['tuple', ['text', 'int']] """ - while 'frozen' in types: - index = types.index('frozen') - types = types[:index] + types[index + 1] + types[index + 2:] - new_types = [_strip_frozen_from_python(item) if isinstance(item, list) else item for item in types] + while "frozen" in types: + index = types.index("frozen") + types = types[:index] + types[index + 1] + types[index + 2 :] + new_types = [ + _strip_frozen_from_python(item) if isinstance(item, list) else item + for item in types + ] return new_types @@ -211,15 +242,15 @@ def parse_casstype_args(typestring): # use a stack of (types, names) lists args = [([], [])] for tok in tokens: - if tok == '(': + if tok == "(": args.append(([], [])) - elif tok == ')': + elif tok == ")": types, names = args.pop() prev_types, prev_names = args[-1] prev_types[-1] = prev_types[-1].apply_parameters(types, names) else: types, names = args[-1] - parts = re.split(':|=>', tok) + parts = re.split(":|=>", tok) tok = parts.pop() if parts: names.append(parts[0]) @@ -235,6 +266,7 @@ def parse_casstype_args(typestring): # return the first (outer) type, which will have all parameters applied return args[0][0][0] + def lookup_casstype(casstype): """ Given a Cassandra type as a string (possibly including parameters), hand @@ -260,12 +292,14 @@ def is_reversed_casstype(data_type): class EmptyValue(object): - """ See _CassandraType.support_empty_values """ + """See _CassandraType.support_empty_values""" def __str__(self): return "EMPTY" + __repr__ = __str__ + EMPTY = EmptyValue() @@ -288,7 +322,7 @@ class _CassandraType(object, metaclass=CassandraTypeType): """ def __repr__(self): - return '<%s>' % (self.cql_parameterized_type()) + return "<%s>" % (self.cql_parameterized_type()) @classmethod def from_binary(cls, byts, protocol_version): @@ -310,7 +344,7 @@ def to_binary(cls, val, protocol_version): more information. This method differs in that if None is passed in, the result is the empty string. """ - return b'' if val is None else cls.serialize(val, protocol_version) + return b"" if val is None else cls.serialize(val, protocol_version) @staticmethod def deserialize(byts, protocol_version): @@ -349,12 +383,14 @@ def cass_parameterized_type_with(cls, subtypes, full=False): 'org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.DecimalType)' """ cname = cls.cassname - if full and '.' not in cname: + if full and "." not in cname: cname = apache_cassandra_type_prefix + cname if not subtypes: return cname - sublist = ', '.join(styp.cass_parameterized_type(full=full) for styp in subtypes) - return '%s(%s)' % (cname, sublist) + sublist = ", ".join( + styp.cass_parameterized_type(full=full) for styp in subtypes + ) + return "%s(%s)" % (cname, sublist) @classmethod def apply_parameters(cls, subtypes, names=None): @@ -368,11 +404,17 @@ def apply_parameters(cls, subtypes, names=None): `subtypes` will be a sequence of CassandraTypes. If provided, `names` will be an equally long sequence of column names or Nones. """ - if cls.num_subtypes != 'UNKNOWN' and len(subtypes) != cls.num_subtypes: - raise ValueError("%s types require %d subtypes (%d given)" - % (cls.typename, cls.num_subtypes, len(subtypes))) + if cls.num_subtypes != "UNKNOWN" and len(subtypes) != cls.num_subtypes: + raise ValueError( + "%s types require %d subtypes (%d given)" + % (cls.typename, cls.num_subtypes, len(subtypes)) + ) newname = cls.cass_parameterized_type_with(subtypes) - return type(newname, (cls,), {'subtypes': subtypes, 'cassname': cls.cassname, 'fieldnames': names}) + return type( + newname, + (cls,), + {"subtypes": subtypes, "cassname": cls.cassname, "fieldnames": names}, + ) @classmethod def cql_parameterized_type(cls): @@ -382,7 +424,10 @@ def cql_parameterized_type(cls): """ if not cls.subtypes: return cls.typename - return '%s<%s>' % (cls.typename, ', '.join(styp.cql_parameterized_type() for styp in cls.subtypes)) + return "%s<%s>" % ( + cls.typename, + ", ".join(styp.cql_parameterized_type() for styp in cls.subtypes), + ) @classmethod def cass_parameterized_type(cls, full=False): @@ -396,23 +441,24 @@ def cass_parameterized_type(cls, full=False): def serial_size(cls): return None + # it's initially named with a _ to avoid registering it as a real type, but # client programs may want to use the name still for isinstance(), etc CassandraType = _CassandraType class _UnrecognizedType(_CassandraType): - num_subtypes = 'UNKNOWN' + num_subtypes = "UNKNOWN" def mkUnrecognizedType(casstypename): - return CassandraTypeType(casstypename, - (_UnrecognizedType,), - {'typename': "'%s'" % casstypename}) + return CassandraTypeType( + casstypename, (_UnrecognizedType,), {"typename": "'%s'" % casstypename} + ) class BytesType(_CassandraType): - typename = 'blob' + typename = "blob" empty_binary_ok = True @staticmethod @@ -421,13 +467,13 @@ def serialize(val, protocol_version): class DecimalType(_CassandraType): - typename = 'decimal' + typename = "decimal" @staticmethod def deserialize(byts, protocol_version): scale = int32_unpack(byts[:4]) unscaled = varint_unpack(byts[4:]) - return Decimal('%de%d' % (unscaled, -scale)) + return Decimal("%de%d" % (unscaled, -scale)) @staticmethod def serialize(dec, protocol_version): @@ -438,7 +484,7 @@ def serialize(dec, protocol_version): sign, digits, exponent = Decimal(dec).as_tuple() except Exception: raise TypeError("Invalid type for Decimal value: %r", dec) - unscaled = int(''.join([str(digit) for digit in digits])) + unscaled = int("".join([str(digit) for digit in digits])) if sign: unscaled *= -1 scale = int32_pack(-exponent) @@ -447,7 +493,7 @@ def serialize(dec, protocol_version): class UUIDType(_CassandraType): - typename = 'uuid' + typename = "uuid" @staticmethod def deserialize(byts, protocol_version): @@ -464,8 +510,9 @@ def serialize(uuid, protocol_version): def serial_size(cls): return 16 + class BooleanType(_CassandraType): - typename = 'boolean' + typename = "boolean" @staticmethod def deserialize(byts, protocol_version): @@ -479,8 +526,9 @@ def serialize(truth, protocol_version): def serial_size(cls): return 1 + class ByteType(_CassandraType): - typename = 'tinyint' + typename = "tinyint" @staticmethod def deserialize(byts, protocol_version): @@ -492,23 +540,22 @@ def serialize(byts, protocol_version): class AsciiType(_CassandraType): - typename = 'ascii' + typename = "ascii" empty_binary_ok = True @staticmethod def deserialize(byts, protocol_version): - return byts.decode('ascii') + return byts.decode("ascii") @staticmethod def serialize(var, protocol_version): - try: - return var.encode('ascii') - except UnicodeDecodeError: + if isinstance(var, bytes): return var + return var.encode("ascii") class FloatType(_CassandraType): - typename = 'float' + typename = "float" @staticmethod def deserialize(byts, protocol_version): @@ -522,8 +569,9 @@ def serialize(byts, protocol_version): def serial_size(cls): return 4 + class DoubleType(_CassandraType): - typename = 'double' + typename = "double" @staticmethod def deserialize(byts, protocol_version): @@ -537,8 +585,9 @@ def serialize(byts, protocol_version): def serial_size(cls): return 8 + class LongType(_CassandraType): - typename = 'bigint' + typename = "bigint" @staticmethod def deserialize(byts, protocol_version): @@ -552,8 +601,9 @@ def serialize(byts, protocol_version): def serial_size(cls): return 8 + class Int32Type(_CassandraType): - typename = 'int' + typename = "int" @staticmethod def deserialize(byts, protocol_version): @@ -567,8 +617,9 @@ def serialize(byts, protocol_version): def serial_size(cls): return 4 + class IntegerType(_CassandraType): - typename = 'varint' + typename = "varint" @staticmethod def deserialize(byts, protocol_version): @@ -580,7 +631,7 @@ def serialize(byts, protocol_version): class InetAddressType(_CassandraType): - typename = 'inet' + typename = "inet" @staticmethod def deserialize(byts, protocol_version): @@ -594,7 +645,7 @@ def deserialize(byts, protocol_version): @staticmethod def serialize(addr, protocol_version): try: - if ':' in addr: + if ":" in addr: return util.inet_pton(socket.AF_INET6, addr) else: # util.inet_pton could also handle, but this is faster @@ -607,26 +658,27 @@ def serialize(addr, protocol_version): class CounterColumnType(LongType): - typename = 'counter' + typename = "counter" + cql_timestamp_formats = ( - '%Y-%m-%d %H:%M', - '%Y-%m-%d %H:%M:%S', - '%Y-%m-%dT%H:%M', - '%Y-%m-%dT%H:%M:%S', - '%Y-%m-%d' + "%Y-%m-%d %H:%M", + "%Y-%m-%d %H:%M:%S", + "%Y-%m-%dT%H:%M", + "%Y-%m-%dT%H:%M:%S", + "%Y-%m-%d", ) _have_warned_about_timestamps = False class DateType(_CassandraType): - typename = 'timestamp' + typename = "timestamp" @staticmethod def interpret_datestring(val): - if val[-5] in ('+', '-'): - offset = (int(val[-4:-2]) * 3600 + int(val[-2:]) * 60) * int(val[-5] + '1') + if val[-5] in ("+", "-"): + offset = (int(val[-4:-2]) * 3600 + int(val[-2:]) * 60) * int(val[-5] + "1") val = val[:-5] else: offset = -time.timezone @@ -650,14 +702,16 @@ def serialize(v, protocol_version): try: # v is datetime timestamp_seconds = calendar.timegm(v.utctimetuple()) - timestamp = timestamp_seconds * 1000 + getattr(v, 'microsecond', 0) // 1000 + timestamp = timestamp_seconds * 1000 + getattr(v, "microsecond", 0) // 1000 except AttributeError: try: timestamp = calendar.timegm(v.timetuple()) * 1000 except AttributeError: # Ints and floats are valid timestamps too if type(v) not in _number_types: - raise TypeError('DateType arguments must be a datetime, date, or timestamp') + raise TypeError( + "DateType arguments must be a datetime, date, or timestamp" + ) timestamp = v return int64_pack(int(timestamp)) @@ -666,12 +720,13 @@ def serialize(v, protocol_version): def serial_size(cls): return 8 + class TimestampType(DateType): pass class TimeUUIDType(DateType): - typename = 'timeuuid' + typename = "timeuuid" def my_timestamp(self): return util.unix_time_from_uuid1(self.val) @@ -691,14 +746,15 @@ def serialize(timeuuid, protocol_version): def serial_size(cls): return 16 + class SimpleDateType(_CassandraType): - typename = 'date' + typename = "date" date_format = "%Y-%m-%d" # Values of the 'date'` type are encoded as 32-bit unsigned integers # representing a number of days with epoch (January 1st, 1970) at the center of the # range (2^31). - EPOCH_OFFSET_DAYS = 2 ** 31 + EPOCH_OFFSET_DAYS = 2**31 @staticmethod def deserialize(byts, protocol_version): @@ -720,7 +776,7 @@ def serialize(val, protocol_version): class ShortType(_CassandraType): - typename = 'smallint' + typename = "smallint" @staticmethod def deserialize(byts, protocol_version): @@ -730,13 +786,14 @@ def deserialize(byts, protocol_version): def serialize(byts, protocol_version): return int16_pack(byts) + class TimeType(_CassandraType): - typename = 'time' + typename = "time" # Time should be a fixed size 8 byte type but Cassandra 5.0 code marks it as # variable size... and we have to match what the server expects since the server # uses that specification to encode data of that type. - #@classmethod - #def serial_size(cls): + # @classmethod + # def serial_size(cls): # return 8 @staticmethod @@ -753,7 +810,7 @@ def serialize(val, protocol_version): class DurationType(_CassandraType): - typename = 'duration' + typename = "duration" @staticmethod def deserialize(byts, protocol_version): @@ -765,65 +822,65 @@ def serialize(duration, protocol_version): try: m, d, n = duration.months, duration.days, duration.nanoseconds except AttributeError: - raise TypeError('DurationType arguments must be a Duration.') + raise TypeError("DurationType arguments must be a Duration.") return vints_pack([m, d, n]) class UTF8Type(_CassandraType): - typename = 'text' + typename = "text" empty_binary_ok = True @staticmethod def deserialize(byts, protocol_version): - return byts.decode('utf8') + return byts.decode("utf8") @staticmethod def serialize(ustr, protocol_version): - try: - return ustr.encode('utf-8') - except UnicodeDecodeError: - # already utf-8 + if isinstance(ustr, bytes): return ustr + return ustr.encode("utf-8") class VarcharType(UTF8Type): - typename = 'varchar' + typename = "varchar" class _ParameterizedType(_CassandraType): - num_subtypes = 'UNKNOWN' + num_subtypes = "UNKNOWN" @classmethod def deserialize(cls, byts, protocol_version): if not cls.subtypes: - raise NotImplementedError("can't deserialize unparameterized %s" - % cls.typename) + raise NotImplementedError( + "can't deserialize unparameterized %s" % cls.typename + ) return cls.deserialize_safe(byts, protocol_version) @classmethod def serialize(cls, val, protocol_version): if not cls.subtypes: - raise NotImplementedError("can't serialize unparameterized %s" - % cls.typename) + raise NotImplementedError( + "can't serialize unparameterized %s" % cls.typename + ) return cls.serialize_safe(val, protocol_version) class _SimpleParameterizedType(_ParameterizedType): @classmethod def deserialize_safe(cls, byts, protocol_version): - subtype, = cls.subtypes + (subtype,) = cls.subtypes length = 4 numelements = int32_unpack(byts[:length]) p = length result = [] inner_proto = max(3, protocol_version) for _ in range(numelements): - itemlen = int32_unpack(byts[p:p + length]) + itemlen = int32_unpack(byts[p : p + length]) p += length if itemlen < 0: result.append(None) else: - item = byts[p:p + itemlen] + item = byts[p : p + itemlen] p += itemlen result.append(subtype.from_binary(item, inner_proto)) return cls.adapter(result) @@ -833,7 +890,7 @@ def serialize_safe(cls, items, protocol_version): if isinstance(items, str): raise TypeError("Received a string for a type that expects a sequence") - subtype, = cls.subtypes + (subtype,) = cls.subtypes buf = io.BytesIO() buf.write(int32_pack(len(items))) inner_proto = max(3, protocol_version) @@ -848,19 +905,19 @@ def serialize_safe(cls, items, protocol_version): class ListType(_SimpleParameterizedType): - typename = 'list' + typename = "list" num_subtypes = 1 adapter = list class SetType(_SimpleParameterizedType): - typename = 'set' + typename = "set" num_subtypes = 1 adapter = util.sortedset class MapType(_ParameterizedType): - typename = 'map' + typename = "map" num_subtypes = 2 @classmethod @@ -872,22 +929,22 @@ def deserialize_safe(cls, byts, protocol_version): themap = util.OrderedMapSerializedKey(key_type, protocol_version) inner_proto = max(3, protocol_version) for _ in range(numelements): - key_len = int32_unpack(byts[p:p + length]) + key_len = int32_unpack(byts[p : p + length]) p += length if key_len < 0: keybytes = None key = None else: - keybytes = byts[p:p + key_len] + keybytes = byts[p : p + key_len] p += key_len key = key_type.from_binary(keybytes, inner_proto) - val_len = int32_unpack(byts[p:p + length]) + val_len = int32_unpack(byts[p : p + length]) p += length if val_len < 0: val = None else: - valbytes = byts[p:p + val_len] + valbytes = byts[p : p + val_len] p += val_len val = value_type.from_binary(valbytes, inner_proto) @@ -921,7 +978,7 @@ def serialize_safe(cls, themap, protocol_version): class TupleType(_ParameterizedType): - typename = 'tuple' + typename = "tuple" @classmethod def deserialize_safe(cls, byts, protocol_version): @@ -931,10 +988,10 @@ def deserialize_safe(cls, byts, protocol_version): for col_type in cls.subtypes: if p == len(byts): break - itemlen = int32_unpack(byts[p:p + 4]) + itemlen = int32_unpack(byts[p : p + 4]) p += 4 if itemlen >= 0: - item = byts[p:p + itemlen] + item = byts[p : p + itemlen] p += itemlen else: item = None @@ -951,8 +1008,10 @@ def deserialize_safe(cls, byts, protocol_version): @classmethod def serialize_safe(cls, val, protocol_version): if len(val) > len(cls.subtypes): - raise ValueError("Expected %d items in a tuple, but got %d: %s" % - (len(cls.subtypes), len(val), val)) + raise ValueError( + "Expected %d items in a tuple, but got %d: %s" + % (len(cls.subtypes), len(val), val) + ) proto_version = max(3, protocol_version) buf = io.BytesIO() @@ -967,8 +1026,10 @@ def serialize_safe(cls, val, protocol_version): @classmethod def cql_parameterized_type(cls): - subtypes_string = ', '.join(sub.cql_parameterized_type() for sub in cls.subtypes) - return 'frozen>' % (subtypes_string,) + subtypes_string = ", ".join( + sub.cql_parameterized_type() for sub in cls.subtypes + ) + return "frozen>" % (subtypes_string,) class UserType(TupleType): @@ -982,14 +1043,26 @@ def make_udt_class(cls, keyspace, udt_name, field_names, field_types): assert len(field_names) == len(field_types) instance = cls._cache.get((keyspace, udt_name)) - if not instance or instance.fieldnames != field_names or instance.subtypes != field_types: - instance = type(udt_name, (cls,), {'subtypes': field_types, - 'cassname': cls.cassname, - 'typename': udt_name, - 'fieldnames': field_names, - 'keyspace': keyspace, - 'mapped_class': None, - 'tuple_type': cls._make_registered_udt_namedtuple(keyspace, udt_name, field_names)}) + if ( + not instance + or instance.fieldnames != field_names + or instance.subtypes != field_types + ): + instance = type( + udt_name, + (cls,), + { + "subtypes": field_types, + "cassname": cls.cassname, + "typename": udt_name, + "fieldnames": field_names, + "keyspace": keyspace, + "mapped_class": None, + "tuple_type": cls._make_registered_udt_namedtuple( + keyspace, udt_name, field_names + ), + }, + ) cls._cache[(keyspace, udt_name)] = instance return instance @@ -1002,9 +1075,13 @@ def evict_udt_class(cls, keyspace, udt_name): @classmethod def apply_parameters(cls, subtypes, names): - keyspace = subtypes[0].cass_parameterized_type() # when parsed from cassandra type, the keyspace is created as an unrecognized cass type; This gets the name back + keyspace = subtypes[ + 0 + ].cass_parameterized_type() # when parsed from cassandra type, the keyspace is created as an unrecognized cass type; This gets the name back udt_name = _name_from_hex_string(subtypes[1].cassname) - field_names = tuple(_name_from_hex_string(encoded_name) for encoded_name in names[2:]) # using tuple here to match what comes into make_udt_class from other sources (for caching equality test) + field_names = tuple( + _name_from_hex_string(encoded_name) for encoded_name in names[2:] + ) # using tuple here to match what comes into make_udt_class from other sources (for caching equality test) return cls.make_udt_class(keyspace, udt_name, field_names, tuple(subtypes[2:])) @classmethod @@ -1032,7 +1109,9 @@ def serialize_safe(cls, val, protocol_version): except TypeError: item = getattr(val, fieldname, None) if item is None and not hasattr(val, fieldname): - log.warning(f"field {fieldname} is part of the UDT {cls.typename} but is not present in the value {val}") + log.warning( + f"field {fieldname} is part of the UDT {cls.typename} but is not present in the value {val}" + ) if item is not None: packed_item = subtype.to_binary(item, proto_version) @@ -1061,15 +1140,21 @@ def _make_udt_tuple_type(cls, name, field_names): t = namedtuple(name, field_names) except ValueError: try: - t = namedtuple(name, util._positional_rename_invalid_identifiers(field_names)) - log.warning("could not create a namedtuple for '%s' because one or more " - "field names are not valid Python identifiers (%s); " - "returning positionally-named fields" % (name, field_names)) + t = namedtuple( + name, util._positional_rename_invalid_identifiers(field_names) + ) + log.warning( + "could not create a namedtuple for '%s' because one or more " + "field names are not valid Python identifiers (%s); " + "returning positionally-named fields" % (name, field_names) + ) except ValueError: t = None - log.warning("could not create a namedtuple for '%s' because the name is " - "not a valid Python identifier; will return tuples in " - "its place" % (name,)) + log.warning( + "could not create a namedtuple for '%s' because the name is " + "not a valid Python identifier; will return tuples in " + "its place" % (name,) + ) return t @@ -1093,10 +1178,10 @@ def deserialize_safe(cls, byts, protocol_version): break element_length = uint16_unpack(byts[:2]) - element = byts[2:2 + element_length] + element = byts[2 : 2 + element_length] # skip element length, element, and the EOC (one byte) - byts = byts[2 + element_length + 1:] + byts = byts[2 + element_length + 1 :] result.append(subtype.from_binary(element, protocol_version)) return tuple(result) @@ -1107,7 +1192,10 @@ class DynamicCompositeType(_ParameterizedType): @classmethod def cql_parameterized_type(cls): - sublist = ', '.join('%s=>%s' % (alias, typ.cass_parameterized_type(full=True)) for alias, typ in zip(cls.fieldnames, cls.subtypes)) + sublist = ", ".join( + "%s=>%s" % (alias, typ.cass_parameterized_type(full=True)) + for alias, typ in zip(cls.fieldnames, cls.subtypes) + ) return "'%s(%s)'" % (cls.typename, sublist) @@ -1117,6 +1205,7 @@ class ColumnToCollectionType(_ParameterizedType): Cassandra includes this. We don't actually need or want the extra information. """ + typename = "org.apache.cassandra.db.marshal.ColumnToCollectionType" @@ -1126,12 +1215,12 @@ class ReversedType(_ParameterizedType): @classmethod def deserialize_safe(cls, byts, protocol_version): - subtype, = cls.subtypes + (subtype,) = cls.subtypes return subtype.from_binary(byts, protocol_version) @classmethod def serialize_safe(cls, val, protocol_version): - subtype, = cls.subtypes + (subtype,) = cls.subtypes return subtype.to_binary(val, protocol_version) @@ -1141,12 +1230,12 @@ class FrozenType(_ParameterizedType): @classmethod def deserialize_safe(cls, byts, protocol_version): - subtype, = cls.subtypes + (subtype,) = cls.subtypes return subtype.from_binary(byts, protocol_version) @classmethod def serialize_safe(cls, val, protocol_version): - subtype, = cls.subtypes + (subtype,) = cls.subtypes return subtype.to_binary(val, protocol_version) @@ -1177,9 +1266,9 @@ class WKBGeometryType(object): class PointType(CassandraType): - typename = 'PointType' + typename = "PointType" - _type = struct.pack('[[]] type_ = int8_unpack(byts[0:1]) - if type_ in (BoundKind.to_int(BoundKind.BOTH_OPEN_RANGE), - BoundKind.to_int(BoundKind.SINGLE_DATE_OPEN)): + if type_ in ( + BoundKind.to_int(BoundKind.BOTH_OPEN_RANGE), + BoundKind.to_int(BoundKind.SINGLE_DATE_OPEN), + ): time0 = precision0 = None else: time0 = int64_unpack(byts[1:9]) @@ -1348,32 +1470,34 @@ def deserialize(cls, byts, protocol_version): if time0 is not None: date_range_bound0 = util.DateRangeBound( - time0, - cls._decode_precision(precision0) + time0, cls._decode_precision(precision0) ) if time1 is not None: date_range_bound1 = util.DateRangeBound( - time1, - cls._decode_precision(precision1) + time1, cls._decode_precision(precision1) ) if type_ == BoundKind.to_int(BoundKind.SINGLE_DATE): return util.DateRange(value=date_range_bound0) if type_ == BoundKind.to_int(BoundKind.CLOSED_RANGE): - return util.DateRange(lower_bound=date_range_bound0, - upper_bound=date_range_bound1) + return util.DateRange( + lower_bound=date_range_bound0, upper_bound=date_range_bound1 + ) if type_ == BoundKind.to_int(BoundKind.OPEN_RANGE_HIGH): - return util.DateRange(lower_bound=date_range_bound0, - upper_bound=util.OPEN_BOUND) + return util.DateRange( + lower_bound=date_range_bound0, upper_bound=util.OPEN_BOUND + ) if type_ == BoundKind.to_int(BoundKind.OPEN_RANGE_LOW): - return util.DateRange(lower_bound=util.OPEN_BOUND, - upper_bound=date_range_bound0) + return util.DateRange( + lower_bound=util.OPEN_BOUND, upper_bound=date_range_bound0 + ) if type_ == BoundKind.to_int(BoundKind.BOTH_OPEN_RANGE): - return util.DateRange(lower_bound=util.OPEN_BOUND, - upper_bound=util.OPEN_BOUND) + return util.DateRange( + lower_bound=util.OPEN_BOUND, upper_bound=util.OPEN_BOUND + ) if type_ == BoundKind.to_int(BoundKind.SINGLE_DATE_OPEN): return util.DateRange(value=util.OPEN_BOUND) - raise ValueError('Could not deserialize %r' % (byts,)) + raise ValueError("Could not deserialize %r" % (byts,)) @classmethod def serialize(cls, v, protocol_version): @@ -1384,8 +1508,8 @@ def serialize(cls, v, protocol_version): value = v.value except AttributeError: raise ValueError( - '%s.serialize expects an object with a value attribute; got' - '%r' % (cls.__name__, v) + "%s.serialize expects an object with a value attribute; got" + "%r" % (cls.__name__, v) ) if value is None: @@ -1393,8 +1517,8 @@ def serialize(cls, v, protocol_version): lower_bound, upper_bound = v.lower_bound, v.upper_bound except AttributeError: raise ValueError( - '%s.serialize expects an object with lower_bound and ' - 'upper_bound attributes; got %r' % (cls.__name__, v) + "%s.serialize expects an object with lower_bound and " + "upper_bound attributes; got %r" % (cls.__name__, v) ) if lower_bound == util.OPEN_BOUND and upper_bound == util.OPEN_BOUND: bound_kind = BoundKind.BOTH_OPEN_RANGE @@ -1415,9 +1539,7 @@ def serialize(cls, v, protocol_version): bounds = (value,) if bound_kind is None: - raise ValueError( - 'Cannot serialize %r; could not find bound kind' % (v,) - ) + raise ValueError("Cannot serialize %r; could not find bound kind" % (v,)) buf.write(int8_pack(BoundKind.to_int(bound_kind))) for bound in bounds: @@ -1426,22 +1548,29 @@ def serialize(cls, v, protocol_version): return buf.getvalue() + class VectorType(_CassandraType): - typename = 'org.apache.cassandra.db.marshal.VectorType' + typename = "org.apache.cassandra.db.marshal.VectorType" vector_size = 0 subtype = None @classmethod def serial_size(cls): serialized_size = cls.subtype.serial_size() - return cls.vector_size * serialized_size if serialized_size is not None else None + return ( + cls.vector_size * serialized_size if serialized_size is not None else None + ) @classmethod def apply_parameters(cls, params, names): assert len(params) == 2 subtype = lookup_casstype(params[0]) vsize = params[1] - return type('%s(%s)' % (cls.cass_parameterized_type_with([]), vsize), (cls,), {'vector_size': vsize, 'subtype': subtype}) + return type( + "%s(%s)" % (cls.cass_parameterized_type_with([]), vsize), + (cls,), + {"vector_size": vsize, "subtype": subtype}, + ) @classmethod def deserialize(cls, byts, protocol_version): @@ -1450,26 +1579,43 @@ def deserialize(cls, byts, protocol_version): expected_byte_size = serialized_size * cls.vector_size if len(byts) != expected_byte_size: raise ValueError( - "Expected vector of type {0} and dimension {1} to have serialized size {2}; observed serialized size of {3} instead"\ - .format(cls.subtype.typename, cls.vector_size, expected_byte_size, len(byts))) + "Expected vector of type {0} and dimension {1} to have serialized size {2}; observed serialized size of {3} instead".format( + cls.subtype.typename, + cls.vector_size, + expected_byte_size, + len(byts), + ) + ) indexes = (serialized_size * x for x in range(0, cls.vector_size)) - return [cls.subtype.deserialize(byts[idx:idx + serialized_size], protocol_version) for idx in indexes] + return [ + cls.subtype.deserialize( + byts[idx : idx + serialized_size], protocol_version + ) + for idx in indexes + ] idx = 0 rv = [] - while (len(rv) < cls.vector_size): + while len(rv) < cls.vector_size: try: size, bytes_read = uvint_unpack(byts[idx:]) idx += bytes_read - rv.append(cls.subtype.deserialize(byts[idx:idx + size], protocol_version)) + rv.append( + cls.subtype.deserialize(byts[idx : idx + size], protocol_version) + ) idx += size except: - raise ValueError("Error reading additional data during vector deserialization after successfully adding {} elements"\ - .format(len(rv))) + raise ValueError( + "Error reading additional data during vector deserialization after successfully adding {} elements".format( + len(rv) + ) + ) # If we have any additional data in the serialized vector treat that as an error as well if idx < len(byts): - raise ValueError("Additional bytes remaining after vector deserialization completed") + raise ValueError( + "Additional bytes remaining after vector deserialization completed" + ) return rv @classmethod @@ -1477,8 +1623,10 @@ def serialize(cls, v, protocol_version): v_length = len(v) if cls.vector_size != v_length: raise ValueError( - "Expected sequence of size {0} for vector of type {1} and dimension {0}, observed sequence of length {2}"\ - .format(cls.vector_size, cls.subtype.typename, v_length)) + "Expected sequence of size {0} for vector of type {1} and dimension {0}, observed sequence of length {2}".format( + cls.vector_size, cls.subtype.typename, v_length + ) + ) serialized_size = cls.subtype.serial_size() buf = io.BytesIO() @@ -1491,4 +1639,8 @@ def serialize(cls, v, protocol_version): @classmethod def cql_parameterized_type(cls): - return "%s<%s, %s>" % (cls.typename, cls.subtype.cql_parameterized_type(), cls.vector_size) + return "%s<%s, %s>" % ( + cls.typename, + cls.subtype.cql_parameterized_type(), + cls.vector_size, + ) From 07f60f9310e326bae6f26c44090862fff259e72a Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Wed, 4 Mar 2026 19:00:25 +0200 Subject: [PATCH 03/18] fix: materialize filter() result into list for Python 3 safety In Python 2, filter() returned a list. In Python 3, it returns a lazy iterator that can only be consumed once. The column_aliases variable assigned from filter() at metadata.py:2273 may be iterated multiple times downstream (e.g., for length checks and enumeration), which would silently produce empty results on the second pass. Wrap the filter() call in list() to ensure the result is a concrete list that supports repeated iteration, indexing, and len(). --- cassandra/metadata.py | 2222 ++++++++++++++++++++++++++++++----------- 1 file changed, 1628 insertions(+), 594 deletions(-) diff --git a/cassandra/metadata.py b/cassandra/metadata.py index b85308449e..3d4a89a0b5 100644 --- a/cassandra/metadata.py +++ b/cassandra/metadata.py @@ -34,7 +34,12 @@ except ImportError as e: pass -from cassandra import SignatureDescriptor, ConsistencyLevel, InvalidRequest, Unauthorized +from cassandra import ( + SignatureDescriptor, + ConsistencyLevel, + InvalidRequest, + Unauthorized, +) import cassandra.cqltypes as types from cassandra.encoder import Encoder from cassandra.marshal import varint_unpack @@ -48,41 +53,272 @@ log = logging.getLogger(__name__) -cql_keywords = set(( - 'add', 'aggregate', 'all', 'allow', 'alter', 'and', 'apply', 'as', 'asc', 'ascii', 'authorize', 'batch', 'begin', - 'bigint', 'blob', 'boolean', 'by', 'cast', 'called', 'clustering', 'columnfamily', 'compact', 'contains', 'count', - 'counter', 'create', 'custom', 'date', 'decimal', 'default', 'delete', 'desc', 'describe', 'deterministic', 'distinct', 'double', 'drop', - 'entries', 'execute', 'exists', 'filtering', 'finalfunc', 'float', 'from', 'frozen', 'full', 'function', - 'functions', 'grant', 'if', 'in', 'index', 'inet', 'infinity', 'initcond', 'input', 'insert', 'int', 'into', 'is', 'json', - 'key', 'keys', 'keyspace', 'keyspaces', 'language', 'limit', 'list', 'login', 'map', 'materialized', 'mbean', 'mbeans', 'modify', 'monotonic', - 'nan', 'nologin', 'norecursive', 'nosuperuser', 'not', 'null', 'of', 'on', 'options', 'or', 'order', 'password', 'permission', - 'permissions', 'primary', 'rename', 'replace', 'returns', 'revoke', 'role', 'roles', 'schema', 'scylla_clustering_bound', - 'scylla_counter_shard_list', 'scylla_timeuuid_list_index', 'select', 'set', 'sfunc', 'smallint', 'static', 'storage', 'stype', 'superuser', - 'table', 'text', 'time', 'timestamp', 'timeuuid', 'tinyint', 'to', 'token', 'trigger', 'truncate', 'ttl', 'tuple', 'type', 'unlogged', - 'unset', 'update', 'use', 'user', 'users', 'using', 'uuid', 'values', 'varchar', 'varint', 'view', 'where', 'with', 'writetime', - - # DSE specifics - "node", "nodes", "plan", "active", "application", "applications", "java", "executor", "executors", "std_out", "std_err", - "renew", "delegation", "no", "redact", "token", "lowercasestring", "cluster", "authentication", "schemes", "scheme", - "internal", "ldap", "kerberos", "remote", "object", "method", "call", "calls", "search", "schema", "config", "rows", - "columns", "profiles", "commit", "reload", "rebuild", "field", "workpool", "any", "submission", "indices", - "restrict", "unrestrict" -)) +cql_keywords = set( + ( + "add", + "aggregate", + "all", + "allow", + "alter", + "and", + "apply", + "as", + "asc", + "ascii", + "authorize", + "batch", + "begin", + "bigint", + "blob", + "boolean", + "by", + "cast", + "called", + "clustering", + "columnfamily", + "compact", + "contains", + "count", + "counter", + "create", + "custom", + "date", + "decimal", + "default", + "delete", + "desc", + "describe", + "deterministic", + "distinct", + "double", + "drop", + "entries", + "execute", + "exists", + "filtering", + "finalfunc", + "float", + "from", + "frozen", + "full", + "function", + "functions", + "grant", + "if", + "in", + "index", + "inet", + "infinity", + "initcond", + "input", + "insert", + "int", + "into", + "is", + "json", + "key", + "keys", + "keyspace", + "keyspaces", + "language", + "limit", + "list", + "login", + "map", + "materialized", + "mbean", + "mbeans", + "modify", + "monotonic", + "nan", + "nologin", + "norecursive", + "nosuperuser", + "not", + "null", + "of", + "on", + "options", + "or", + "order", + "password", + "permission", + "permissions", + "primary", + "rename", + "replace", + "returns", + "revoke", + "role", + "roles", + "schema", + "scylla_clustering_bound", + "scylla_counter_shard_list", + "scylla_timeuuid_list_index", + "select", + "set", + "sfunc", + "smallint", + "static", + "storage", + "stype", + "superuser", + "table", + "text", + "time", + "timestamp", + "timeuuid", + "tinyint", + "to", + "token", + "trigger", + "truncate", + "ttl", + "tuple", + "type", + "unlogged", + "unset", + "update", + "use", + "user", + "users", + "using", + "uuid", + "values", + "varchar", + "varint", + "view", + "where", + "with", + "writetime", + # DSE specifics + "node", + "nodes", + "plan", + "active", + "application", + "applications", + "java", + "executor", + "executors", + "std_out", + "std_err", + "renew", + "delegation", + "no", + "redact", + "token", + "lowercasestring", + "cluster", + "authentication", + "schemes", + "scheme", + "internal", + "ldap", + "kerberos", + "remote", + "object", + "method", + "call", + "calls", + "search", + "schema", + "config", + "rows", + "columns", + "profiles", + "commit", + "reload", + "rebuild", + "field", + "workpool", + "any", + "submission", + "indices", + "restrict", + "unrestrict", + ) +) """ Set of keywords in CQL. Derived from .../cassandra/src/java/org/apache/cassandra/cql3/Cql.g """ -cql_keywords_unreserved = set(( - 'aggregate', 'all', 'as', 'ascii', 'bigint', 'blob', 'boolean', 'called', 'clustering', 'compact', 'contains', - 'count', 'counter', 'custom', 'date', 'decimal', 'deterministic', 'distinct', 'double', 'exists', 'filtering', 'finalfunc', 'float', - 'frozen', 'function', 'functions', 'inet', 'initcond', 'input', 'int', 'json', 'key', 'keys', 'keyspaces', - 'language', 'list', 'login', 'map', 'monotonic', 'nologin', 'nosuperuser', 'options', 'password', 'permission', 'permissions', - 'returns', 'role', 'roles', 'sfunc', 'smallint', 'static', 'storage', 'stype', 'superuser', 'text', 'time', - 'timestamp', 'timeuuid', 'tinyint', 'trigger', 'ttl', 'tuple', 'type', 'user', 'users', 'uuid', 'values', 'varchar', - 'varint', 'writetime' -)) +cql_keywords_unreserved = set( + ( + "aggregate", + "all", + "as", + "ascii", + "bigint", + "blob", + "boolean", + "called", + "clustering", + "compact", + "contains", + "count", + "counter", + "custom", + "date", + "decimal", + "deterministic", + "distinct", + "double", + "exists", + "filtering", + "finalfunc", + "float", + "frozen", + "function", + "functions", + "inet", + "initcond", + "input", + "int", + "json", + "key", + "keys", + "keyspaces", + "language", + "list", + "login", + "map", + "monotonic", + "nologin", + "nosuperuser", + "options", + "password", + "permission", + "permissions", + "returns", + "role", + "roles", + "sfunc", + "smallint", + "static", + "storage", + "stype", + "superuser", + "text", + "time", + "timestamp", + "timeuuid", + "tinyint", + "trigger", + "ttl", + "tuple", + "type", + "user", + "users", + "uuid", + "values", + "varchar", + "varint", + "writetime", + ) +) """ Set of unreserved keywords in CQL. @@ -136,13 +372,28 @@ def export_schema_as_string(self): """ return "\n\n".join(ks.export_as_string() for ks in self.keyspaces.values()) - def refresh(self, connection, timeout, target_type=None, change_type=None, fetch_size=None, - metadata_request_timeout=None, **kwargs): + def refresh( + self, + connection, + timeout, + target_type=None, + change_type=None, + fetch_size=None, + metadata_request_timeout=None, + **kwargs, + ): host = self.get_host(connection.original_endpoint) server_version = host.release_version if host else None dse_version = host.dse_version if host else None - parser = get_schema_parser(connection, server_version, dse_version, timeout, metadata_request_timeout, fetch_size) + parser = get_schema_parser( + connection, + server_version, + dse_version, + timeout, + metadata_request_timeout, + fetch_size, + ) if not target_type: self._rebuild_all(parser) @@ -150,13 +401,13 @@ def refresh(self, connection, timeout, target_type=None, change_type=None, fetch tt_lower = target_type.lower() try: - parse_method = getattr(parser, 'get_' + tt_lower) + parse_method = getattr(parser, "get_" + tt_lower) meta = parse_method(self.keyspaces, **kwargs) if meta: - update_method = getattr(self, '_update_' + tt_lower) + update_method = getattr(self, "_update_" + tt_lower) update_method(meta) else: - drop_method = getattr(self, '_drop_' + tt_lower) + drop_method = getattr(self, "_drop_" + tt_lower) drop_method(**kwargs) except AttributeError: raise ValueError("Unknown schema target_type: '%s'" % target_type) @@ -165,7 +416,9 @@ def _rebuild_all(self, parser): current_keyspaces = set() for keyspace_meta in parser.get_all_keyspaces(): current_keyspaces.add(keyspace_meta.name) - old_keyspace_meta: Optional[KeyspaceMetadata] = self.keyspaces.get(keyspace_meta.name, None) + old_keyspace_meta: Optional[KeyspaceMetadata] = self.keyspaces.get( + keyspace_meta.name, None + ) self.keyspaces[keyspace_meta.name] = keyspace_meta if old_keyspace_meta: self._keyspace_updated(keyspace_meta.name) @@ -176,10 +429,14 @@ def _rebuild_all(self, parser): self._keyspace_added(keyspace_meta.name) # remove not-just-added keyspaces - removed_keyspaces = [name for name in self.keyspaces.keys() - if name not in current_keyspaces] - self.keyspaces = dict((name, meta) for name, meta in self.keyspaces.items() - if name in current_keyspaces) + removed_keyspaces = [ + name for name in self.keyspaces.keys() if name not in current_keyspaces + ] + self.keyspaces = dict( + (name, meta) + for name, meta in self.keyspaces.items() + if name in current_keyspaces + ) for ksname in removed_keyspaces: self._keyspace_removed(ksname) @@ -189,12 +446,19 @@ def _update_keyspace(self, keyspace_meta, new_user_types=None): self.keyspaces[ks_name] = keyspace_meta if old_keyspace_meta: keyspace_meta.tables = old_keyspace_meta.tables - keyspace_meta.user_types = new_user_types if new_user_types is not None else old_keyspace_meta.user_types + keyspace_meta.user_types = ( + new_user_types + if new_user_types is not None + else old_keyspace_meta.user_types + ) keyspace_meta.indexes = old_keyspace_meta.indexes keyspace_meta.functions = old_keyspace_meta.functions keyspace_meta.aggregates = old_keyspace_meta.aggregates keyspace_meta.views = old_keyspace_meta.views - if (keyspace_meta.replication_strategy != old_keyspace_meta.replication_strategy): + if ( + keyspace_meta.replication_strategy + != old_keyspace_meta.replication_strategy + ): self._keyspace_updated(ks_name) else: self._keyspace_added(ks_name) @@ -242,7 +506,9 @@ def _drop_type(self, keyspace, type): def _update_function(self, function_meta): try: - self.keyspaces[function_meta.keyspace].functions[function_meta.signature] = function_meta + self.keyspaces[function_meta.keyspace].functions[ + function_meta.signature + ] = function_meta except KeyError: # can happen if keyspace disappears while processing async event pass @@ -255,7 +521,9 @@ def _drop_function(self, keyspace, function): def _update_aggregate(self, aggregate_meta): try: - self.keyspaces[aggregate_meta.keyspace].aggregates[aggregate_meta.signature] = aggregate_meta + self.keyspaces[aggregate_meta.keyspace].aggregates[ + aggregate_meta.signature + ] = aggregate_meta except KeyError: pass @@ -289,11 +557,11 @@ def rebuild_token_map(self, partitioner, token_map): For internal use only. """ self.partitioner = partitioner - if partitioner.endswith('RandomPartitioner'): + if partitioner.endswith("RandomPartitioner"): token_class = MD5Token - elif partitioner.endswith('Murmur3Partitioner'): + elif partitioner.endswith("Murmur3Partitioner"): token_class = Murmur3Token - elif partitioner.endswith('ByteOrderedPartitioner'): + elif partitioner.endswith("ByteOrderedPartitioner"): token_class = BytesToken else: self.token_map = None @@ -308,8 +576,7 @@ def rebuild_token_map(self, partitioner, token_map): token_to_host_owner[token] = host all_tokens = sorted(ring) - self.token_map = TokenMap( - token_class, token_to_host_owner, all_tokens, self) + self.token_map = TokenMap(token_class, token_to_host_owner, all_tokens, self) def get_replicas(self, keyspace, key): """ @@ -325,7 +592,7 @@ def get_replicas(self, keyspace, key): return [] def can_support_partitioner(self): - if self.partitioner.endswith('Murmur3Partitioner') and murmur3 is None: + if self.partitioner.endswith("Murmur3Partitioner") and murmur3 is None: return False else: return True @@ -385,8 +652,11 @@ def get_host_by_host_id(self, host_id): def _get_host_by_address(self, address, port=None): for host in self._hosts.values(): - if (host.broadcast_rpc_address == address and - (port is None or host.broadcast_rpc_port is None or host.broadcast_rpc_port == port)): + if host.broadcast_rpc_address == address and ( + port is None + or host.broadcast_rpc_port is None + or host.broadcast_rpc_port == port + ): return host return None @@ -408,7 +678,7 @@ def all_hosts_items(self): def trim_if_startswith(s, prefix): if s.startswith(prefix): - return s[len(prefix):] + return s[len(prefix) :] return s @@ -417,14 +687,13 @@ def trim_if_startswith(s, prefix): class ReplicationStrategyTypeType(type): def __new__(metacls, name, bases, dct): - dct.setdefault('name', name) + dct.setdefault("name", name) cls = type.__new__(metacls, name, bases, dct) - if not name.startswith('_'): + if not name.startswith("_"): _replication_strategies[name] = cls return cls - class _ReplicationStrategy(object, metaclass=ReplicationStrategyTypeType): options_map = None @@ -433,7 +702,9 @@ def create(cls, strategy_class, options_map): if not strategy_class: return None - strategy_name = trim_if_startswith(strategy_class, REPLICATION_STRATEGY_CLASS_PREFIX) + strategy_name = trim_if_startswith( + strategy_class, REPLICATION_STRATEGY_CLASS_PREFIX + ) rs_class = _replication_strategies.get(strategy_name, None) if rs_class is None: @@ -443,7 +714,12 @@ def create(cls, strategy_class, options_map): try: rs_instance = rs_class(options_map) except Exception as exc: - log.warning("Failed creating %s with options %s: %s", strategy_name, options_map, exc) + log.warning( + "Failed creating %s with options %s: %s", + strategy_name, + options_map, + exc, + ) return None return rs_instance @@ -471,12 +747,14 @@ class _UnknownStrategy(ReplicationStrategy): def __init__(self, name, options_map): self.name = name self.options_map = options_map.copy() if options_map is not None else dict() - self.options_map['class'] = self.name + self.options_map["class"] = self.name def __eq__(self, other): - return (isinstance(other, _UnknownStrategy) and - self.name == other.name and - self.options_map == other.options_map) + return ( + isinstance(other, _UnknownStrategy) + and self.name == other.name + and self.options_map == other.options_map + ) def export_for_schema(self): """ @@ -484,8 +762,10 @@ def export_for_schema(self): suitable for use in a CREATE KEYSPACE statement. """ if self.options_map: - return dict((str(key), str(value)) for key, value in self.options_map.items()) - return "{'class': '%s'}" % (self.name, ) + return dict( + (str(key), str(value)) for key, value in self.options_map.items() + ) + return "{'class': '%s'}" % (self.name,) def make_token_replica_map(self, token_to_host_owner, ring): return {} @@ -517,7 +797,9 @@ class ReplicationFactor(object): def __init__(self, all_replicas, transient_replicas=None): self.all_replicas = all_replicas self.transient_replicas = transient_replicas - self.full_replicas = (all_replicas - transient_replicas) if transient_replicas else all_replicas + self.full_replicas = ( + (all_replicas - transient_replicas) if transient_replicas else all_replicas + ) @staticmethod def create(rf): @@ -529,26 +811,33 @@ def create(rf): all_replicas = int(rf) except ValueError: try: - rf = rf.split('/') + rf = rf.split("/") all_replicas, transient_replicas = int(rf[0]), int(rf[1]) except Exception: - raise ValueError("Unable to determine replication factor from: {}".format(rf)) + raise ValueError( + "Unable to determine replication factor from: {}".format(rf) + ) return ReplicationFactor(all_replicas, transient_replicas) def __str__(self): - return ("%d/%d" % (self.all_replicas, self.transient_replicas) if self.transient_replicas - else "%d" % self.all_replicas) + return ( + "%d/%d" % (self.all_replicas, self.transient_replicas) + if self.transient_replicas + else "%d" % self.all_replicas + ) def __eq__(self, other): if not isinstance(other, ReplicationFactor): return False - return self.all_replicas == other.all_replicas and self.full_replicas == other.full_replicas + return ( + self.all_replicas == other.all_replicas + and self.full_replicas == other.full_replicas + ) class SimpleStrategy(ReplicationStrategy): - replication_factor_info = None """ A :class:`cassandra.metadata.ReplicationFactor` instance. @@ -566,7 +855,9 @@ def replication_factor(self): return self.replication_factor_info.full_replicas def __init__(self, options_map): - self.replication_factor_info = ReplicationFactor.create(options_map['replication_factor']) + self.replication_factor_info = ReplicationFactor.create( + options_map["replication_factor"] + ) def make_token_replica_map(self, token_to_host_owner, ring): replica_map = {} @@ -588,8 +879,9 @@ def export_for_schema(self): Returns a string version of these replication options which are suitable for use in a CREATE KEYSPACE statement. """ - return "{'class': 'SimpleStrategy', 'replication_factor': '%s'}" \ - % (str(self.replication_factor_info),) + return "{'class': 'SimpleStrategy', 'replication_factor': '%s'}" % ( + str(self.replication_factor_info), + ) def __eq__(self, other): if not isinstance(other, SimpleStrategy): @@ -599,7 +891,6 @@ def __eq__(self, other): class NetworkTopologyStrategy(ReplicationStrategy): - dc_replication_factors_info = None """ A map of datacenter names to the :class:`cassandra.metadata.ReplicationFactor` instance for that DC. @@ -615,14 +906,20 @@ class NetworkTopologyStrategy(ReplicationStrategy): def __init__(self, dc_replication_factors): self.dc_replication_factors_info = dict( - (str(k), ReplicationFactor.create(v)) for k, v in dc_replication_factors.items()) + (str(k), ReplicationFactor.create(v)) + for k, v in dc_replication_factors.items() + ) self.dc_replication_factors = dict( - (dc, rf.full_replicas) for dc, rf in self.dc_replication_factors_info.items()) + (dc, rf.full_replicas) + for dc, rf in self.dc_replication_factors_info.items() + ) def make_token_replica_map(self, token_to_host_owner, ring): dc_rf_map = dict( - (dc, full_replicas) for dc, full_replicas in self.dc_replication_factors.items() - if full_replicas > 0) + (dc, full_replicas) + for dc, full_replicas in self.dc_replication_factors.items() + if full_replicas > 0 + ) # build a map of DCs to lists of indexes into `ring` for tokens that # belong to that DC @@ -650,7 +947,6 @@ def make_token_replica_map(self, token_to_host_owner, ring): # go through each DC and find the replicas in that DC for dc in dc_to_token_offset.keys(): - # advance our per-DC index until we're up to at least the # current token in the ring token_offsets = dc_to_token_offset[dc] @@ -667,8 +963,11 @@ def make_token_replica_map(self, token_to_host_owner, ring): num_racks_this_dc = len(dc_racks[dc]) num_hosts_this_dc = len(hosts_per_dc[dc]) - for token_offset_index in range(index, index+num_tokens): - if replicas_remaining == 0 or num_replicas_this_dc == num_hosts_this_dc: + for token_offset_index in range(index, index + num_tokens): + if ( + replicas_remaining == 0 + or num_replicas_this_dc == num_hosts_this_dc + ): break if token_offset_index >= num_tokens: @@ -679,7 +978,10 @@ def make_token_replica_map(self, token_to_host_owner, ring): if host in replicas: continue - if host.rack in racks_placed and len(racks_placed) < num_racks_this_dc: + if ( + host.rack in racks_placed + and len(racks_placed) < num_racks_this_dc + ): skipped_hosts.append(host) continue @@ -804,10 +1106,14 @@ class KeyspaceMetadata(object): _exc_info = None """ set if metadata parsing failed """ - def __init__(self, name, durable_writes, strategy_class, strategy_options, graph_engine=None): + def __init__( + self, name, durable_writes, strategy_class, strategy_options, graph_engine=None + ): self.name = name self.durable_writes = durable_writes - self.replication_strategy = ReplicationStrategy.create(strategy_class, strategy_options) + self.replication_strategy = ReplicationStrategy.create( + strategy_class, strategy_options + ) self.tables = {} self.indexes = {} self.user_types = {} @@ -826,29 +1132,40 @@ def export_as_string(self): including user-defined types and tables. """ # Make sure tables with vertex are exported before tables with edges - tables_with_vertex = [t for t in self.tables.values() if hasattr(t, 'vertex') and t.vertex] + tables_with_vertex = [ + t for t in self.tables.values() if hasattr(t, "vertex") and t.vertex + ] other_tables = [t for t in self.tables.values() if t not in tables_with_vertex] cql = "\n\n".join( - [self.as_cql_query() + ';'] + - self.user_type_strings() + - [f.export_as_string() for f in self.functions.values()] + - [a.export_as_string() for a in self.aggregates.values()] + - [t.export_as_string() for t in tables_with_vertex + other_tables]) + [self.as_cql_query() + ";"] + + self.user_type_strings() + + [f.export_as_string() for f in self.functions.values()] + + [a.export_as_string() for a in self.aggregates.values()] + + [t.export_as_string() for t in tables_with_vertex + other_tables] + ) if self._exc_info: import traceback - ret = "/*\nWarning: Keyspace %s is incomplete because of an error processing metadata.\n" % \ - (self.name) + + ret = ( + "/*\nWarning: Keyspace %s is incomplete because of an error processing metadata.\n" + % (self.name) + ) for line in traceback.format_exception(*self._exc_info): ret += line - ret += "\nApproximate structure, for reference:\n(this should not be used to reproduce this schema)\n\n%s\n*/" % cql + ret += ( + "\nApproximate structure, for reference:\n(this should not be used to reproduce this schema)\n\n%s\n*/" + % cql + ) return ret if self.virtual: - return ("/*\nWarning: Keyspace {ks} is a virtual keyspace and cannot be recreated with CQL.\n" - "Structure, for reference:*/\n" - "{cql}\n" - "").format(ks=self.name, cql=cql) + return ( + "/*\nWarning: Keyspace {ks} is a virtual keyspace and cannot be recreated with CQL.\n" + "Structure, for reference:*/\n" + "{cql}\n" + "" + ).format(ks=self.name, cql=cql) return cql def as_cql_query(self): @@ -860,8 +1177,11 @@ def as_cql_query(self): return "// VIRTUAL KEYSPACE {}".format(protect_name(self.name)) ret = "CREATE KEYSPACE %s WITH replication = %s " % ( protect_name(self.name), - self.replication_strategy.export_for_schema()) - ret = ret + (' AND durable_writes = %s' % ("true" if self.durable_writes else "false")) + self.replication_strategy.export_for_schema(), + ) + ret = ret + ( + " AND durable_writes = %s" % ("true" if self.durable_writes else "false") + ) if self.graph_engine is not None: ret = ret + (" AND graph_engine = '%s'" % self.graph_engine) return ret @@ -921,7 +1241,9 @@ def _drop_table_metadata(self, table_name): def _add_view_metadata(self, view_metadata): try: - self.tables[view_metadata.base_table_name].views[view_metadata.name] = view_metadata + self.tables[view_metadata.base_table_name].views[view_metadata.name] = ( + view_metadata + ) self.views[view_metadata.name] = view_metadata except KeyError: pass @@ -972,7 +1294,8 @@ def as_cql_query(self, formatted=False): ret = "CREATE TYPE %s.%s (%s" % ( protect_name(self.keyspace), protect_name(self.name), - "\n" if formatted else "") + "\n" if formatted else "", + ) if formatted: field_join = ",\n" @@ -990,7 +1313,7 @@ def as_cql_query(self, formatted=False): return ret def export_as_string(self): - return self.as_cql_query(formatted=True) + ';' + return self.as_cql_query(formatted=True) + ";" class Aggregate(object): @@ -1048,9 +1371,18 @@ class Aggregate(object): for a particular input and state. This is available only with DSE >=6.0. """ - def __init__(self, keyspace, name, argument_types, state_func, - state_type, final_func, initial_condition, return_type, - deterministic): + def __init__( + self, + keyspace, + name, + argument_types, + state_func, + state_type, + final_func, + initial_condition, + return_type, + deterministic, + ): self.keyspace = keyspace self.name = name self.argument_types = argument_types @@ -1067,25 +1399,37 @@ def as_cql_query(self, formatted=False): If `formatted` is set to :const:`True`, extra whitespace will be added to make the query more readable. """ - sep = '\n ' if formatted else ' ' + sep = "\n " if formatted else " " keyspace = protect_name(self.keyspace) name = protect_name(self.name) - type_list = ', '.join([types.strip_frozen(arg_type) for arg_type in self.argument_types]) + type_list = ", ".join( + [types.strip_frozen(arg_type) for arg_type in self.argument_types] + ) state_func = protect_name(self.state_func) state_type = types.strip_frozen(self.state_type) - ret = "CREATE AGGREGATE %(keyspace)s.%(name)s(%(type_list)s)%(sep)s" \ - "SFUNC %(state_func)s%(sep)s" \ - "STYPE %(state_type)s" % locals() + ret = ( + "CREATE AGGREGATE %(keyspace)s.%(name)s(%(type_list)s)%(sep)s" + "SFUNC %(state_func)s%(sep)s" + "STYPE %(state_type)s" % locals() + ) - ret += ''.join((sep, 'FINALFUNC ', protect_name(self.final_func))) if self.final_func else '' - ret += ''.join((sep, 'INITCOND ', self.initial_condition)) if self.initial_condition is not None else '' - ret += '{}DETERMINISTIC'.format(sep) if self.deterministic else '' + ret += ( + "".join((sep, "FINALFUNC ", protect_name(self.final_func))) + if self.final_func + else "" + ) + ret += ( + "".join((sep, "INITCOND ", self.initial_condition)) + if self.initial_condition is not None + else "" + ) + ret += "{}DETERMINISTIC".format(sep) if self.deterministic else "" return ret def export_as_string(self): - return self.as_cql_query(formatted=True) + ';' + return self.as_cql_query(formatted=True) + ";" @property def signature(self): @@ -1160,9 +1504,20 @@ class Function(object): monotonic. This is available only for DSE >=6.0. """ - def __init__(self, keyspace, name, argument_types, argument_names, - return_type, language, body, called_on_null_input, - deterministic, monotonic, monotonic_on): + def __init__( + self, + keyspace, + name, + argument_types, + argument_names, + return_type, + language, + body, + called_on_null_input, + deterministic, + monotonic, + monotonic_on, + ): self.keyspace = keyspace self.name = name self.argument_types = argument_types @@ -1183,39 +1538,44 @@ def as_cql_query(self, formatted=False): If `formatted` is set to :const:`True`, extra whitespace will be added to make the query more readable. """ - sep = '\n ' if formatted else ' ' + sep = "\n " if formatted else " " keyspace = protect_name(self.keyspace) name = protect_name(self.name) - arg_list = ', '.join(["%s %s" % (protect_name(n), types.strip_frozen(t)) - for n, t in zip(self.argument_names, self.argument_types)]) + arg_list = ", ".join( + [ + "%s %s" % (protect_name(n), types.strip_frozen(t)) + for n, t in zip(self.argument_names, self.argument_types) + ] + ) typ = self.return_type lang = self.language body = self.body on_null = "CALLED" if self.called_on_null_input else "RETURNS NULL" - deterministic_token = ('DETERMINISTIC{}'.format(sep) - if self.deterministic else - '') - monotonic_tokens = '' # default for nonmonotonic function + deterministic_token = ( + "DETERMINISTIC{}".format(sep) if self.deterministic else "" + ) + monotonic_tokens = "" # default for nonmonotonic function if self.monotonic: # monotonic on all arguments; ignore self.monotonic_on - monotonic_tokens = 'MONOTONIC{}'.format(sep) + monotonic_tokens = "MONOTONIC{}".format(sep) elif self.monotonic_on: # if monotonic == False and monotonic_on is nonempty, we know that # monotonicity was specified with MONOTONIC ON , so there's # exactly 1 value there - monotonic_tokens = 'MONOTONIC ON {}{}'.format(self.monotonic_on[0], - sep) - - return "CREATE FUNCTION %(keyspace)s.%(name)s(%(arg_list)s)%(sep)s" \ - "%(on_null)s ON NULL INPUT%(sep)s" \ - "RETURNS %(typ)s%(sep)s" \ - "%(deterministic_token)s" \ - "%(monotonic_tokens)s" \ - "LANGUAGE %(lang)s%(sep)s" \ - "AS $$%(body)s$$" % locals() + monotonic_tokens = "MONOTONIC ON {}{}".format(self.monotonic_on[0], sep) + + return ( + "CREATE FUNCTION %(keyspace)s.%(name)s(%(arg_list)s)%(sep)s" + "%(on_null)s ON NULL INPUT%(sep)s" + "RETURNS %(typ)s%(sep)s" + "%(deterministic_token)s" + "%(monotonic_tokens)s" + "LANGUAGE %(lang)s%(sep)s" + "AS $$%(body)s$$" % locals() + ) def export_as_string(self): - return self.as_cql_query(formatted=True) + ';' + return self.as_cql_query(formatted=True) + ";" @property def signature(self): @@ -1279,7 +1639,8 @@ def primary_key(self): compaction_options = { "min_compaction_threshold": "min_threshold", "max_compaction_threshold": "max_threshold", - "compaction_strategy_class": "class"} + "compaction_strategy_class": "class", + } triggers = None """ @@ -1309,13 +1670,15 @@ def is_cql_compatible(self): """ if self.virtual: return False - comparator = getattr(self, 'comparator', None) + comparator = getattr(self, "comparator", None) if comparator: # no compact storage with more than one column beyond PK if there # are clustering columns - incompatible = (self.is_compact_storage and - len(self.columns) > len(self.primary_key) + 1 and - len(self.clustering_key) >= 1) + incompatible = ( + self.is_compact_storage + and len(self.columns) > len(self.primary_key) + 1 + and len(self.clustering_key) >= 1 + ) return not incompatible return True @@ -1325,7 +1688,17 @@ def is_cql_compatible(self): Metadata describing configuration for table extensions """ - def __init__(self, keyspace_name, name, partition_key=None, clustering_key=None, columns=None, triggers=None, options=None, virtual=False): + def __init__( + self, + keyspace_name, + name, + partition_key=None, + clustering_key=None, + columns=None, + triggers=None, + options=None, + virtual=False, + ): self.keyspace_name = keyspace_name self.name = name self.partition_key = [] if partition_key is None else partition_key @@ -1346,20 +1719,33 @@ def export_as_string(self): """ if self._exc_info: import traceback - ret = "/*\nWarning: Table %s.%s is incomplete because of an error processing metadata.\n" % \ - (self.keyspace_name, self.name) + + ret = ( + "/*\nWarning: Table %s.%s is incomplete because of an error processing metadata.\n" + % (self.keyspace_name, self.name) + ) for line in traceback.format_exception(*self._exc_info): ret += line - ret += "\nApproximate structure, for reference:\n(this should not be used to reproduce this schema)\n\n%s\n*/" % self._all_as_cql() + ret += ( + "\nApproximate structure, for reference:\n(this should not be used to reproduce this schema)\n\n%s\n*/" + % self._all_as_cql() + ) elif not self.is_cql_compatible: # If we can't produce this table with CQL, comment inline - ret = "/*\nWarning: Table %s.%s omitted because it has constructs not compatible with CQL (was created via legacy API).\n" % \ - (self.keyspace_name, self.name) - ret += "\nApproximate structure, for reference:\n(this should not be used to reproduce this schema)\n\n%s\n*/" % self._all_as_cql() + ret = ( + "/*\nWarning: Table %s.%s omitted because it has constructs not compatible with CQL (was created via legacy API).\n" + % (self.keyspace_name, self.name) + ) + ret += ( + "\nApproximate structure, for reference:\n(this should not be used to reproduce this schema)\n\n%s\n*/" + % self._all_as_cql() + ) elif self.virtual: - ret = ('/*\nWarning: Table {ks}.{tab} is a virtual table and cannot be recreated with CQL.\n' - 'Structure, for reference:\n' - '{cql}\n*/').format(ks=self.keyspace_name, tab=self.name, cql=self._all_as_cql()) + ret = ( + "/*\nWarning: Table {ks}.{tab} is a virtual table and cannot be recreated with CQL.\n" + "Structure, for reference:\n" + "{cql}\n*/" + ).format(ks=self.keyspace_name, tab=self.name, cql=self._all_as_cql()) else: ret = self._all_as_cql() @@ -1381,7 +1767,9 @@ def _all_as_cql(self): if self.extensions: registry = _RegisteredExtensionType._extension_registry - for k in registry.keys() & self.extensions: # no viewkeys on OrderedMapSerializeKey + for k in ( + registry.keys() & self.extensions + ): # no viewkeys on OrderedMapSerializeKey ext = registry[k] cql = ext.after_table_cql(self, k, self.extensions[k]) if cql: @@ -1396,10 +1784,11 @@ def as_cql_query(self, formatted=False): extra whitespace will be added to make the query human readable. """ ret = "%s TABLE %s.%s (%s" % ( - ('VIRTUAL' if self.virtual else 'CREATE'), + ("VIRTUAL" if self.virtual else "CREATE"), protect_name(self.keyspace_name), protect_name(self.name), - "\n" if formatted else "") + "\n" if formatted else "", + ) if formatted: column_join = ",\n" @@ -1410,7 +1799,14 @@ def as_cql_query(self, formatted=False): columns = [] for col in self.columns.values(): - columns.append("%s %s%s" % (protect_name(col.name), col.cql_type, ' static' if col.is_static else '')) + columns.append( + "%s %s%s" + % ( + protect_name(col.name), + col.cql_type, + " static" if col.is_static else "", + ) + ) if len(self.partition_key) == 1 and not self.clustering_key: columns[0] += " PRIMARY KEY" @@ -1422,23 +1818,31 @@ def as_cql_query(self, formatted=False): ret += "%s%sPRIMARY KEY (" % (column_join, padding) if len(self.partition_key) > 1: - ret += "(%s)" % ", ".join(protect_name(col.name) for col in self.partition_key) + ret += "(%s)" % ", ".join( + protect_name(col.name) for col in self.partition_key + ) else: ret += protect_name(self.partition_key[0].name) if self.clustering_key: - ret += ", %s" % ", ".join(protect_name(col.name) for col in self.clustering_key) + ret += ", %s" % ", ".join( + protect_name(col.name) for col in self.clustering_key + ) ret += ")" # properties ret += "%s) WITH " % ("\n" if formatted else "") - ret += self._property_string(formatted, self.clustering_key, self.options, self.is_compact_storage) + ret += self._property_string( + formatted, self.clustering_key, self.options, self.is_compact_storage + ) return ret @classmethod - def _property_string(cls, formatted, clustering_key, options_map, is_compact_storage=False): + def _property_string( + cls, formatted, clustering_key, options_map, is_compact_storage=False + ): properties = [] if is_compact_storage: properties.append("COMPACT STORAGE") @@ -1464,21 +1868,25 @@ def _make_option_strings(cls, options_map): ret = [] options_copy = dict(options_map.items()) - actual_options = json.loads(options_copy.pop('compaction_strategy_options', '{}')) + actual_options = json.loads( + options_copy.pop("compaction_strategy_options", "{}") + ) value = options_copy.pop("compaction_strategy_class", None) actual_options.setdefault("class", value) - compaction_option_strings = ["'%s': '%s'" % (k, v) for k, v in actual_options.items()] - ret.append('compaction = {%s}' % ', '.join(compaction_option_strings)) + compaction_option_strings = [ + "'%s': '%s'" % (k, v) for k, v in actual_options.items() + ] + ret.append("compaction = {%s}" % ", ".join(compaction_option_strings)) for system_table_name in cls.compaction_options.keys(): options_copy.pop(system_table_name, None) # delete if present - options_copy.pop('compaction_strategy_option', None) + options_copy.pop("compaction_strategy_option", None) - if not options_copy.get('compression'): - params = json.loads(options_copy.pop('compression_parameters', '{}')) + if not options_copy.get("compression"): + params = json.loads(options_copy.pop("compression_parameters", "{}")) param_strings = ["'%s': '%s'" % (k, v) for k, v in params.items()] - ret.append('compression = {%s}' % ', '.join(param_strings)) + ret.append("compression = {%s}" % ", ".join(param_strings)) for name, value in options_copy.items(): if value is not None: @@ -1494,11 +1902,14 @@ class TableMetadataV3(TableMetadata): For C* 3.0+. `option_maps` take a superset of map names, so if nothing changes structurally, new option maps can just be appended to the list. """ + compaction_options = {} option_maps = [ - 'compaction', 'compression', 'caching', - 'nodesync' # added DSE 6.0 + "compaction", + "compression", + "caching", + "nodesync", # added DSE 6.0 ] @property @@ -1515,7 +1926,7 @@ def _make_option_strings(cls, options_map): if isinstance(value, Mapping): del options_copy[option] params = ("'%s': '%s'" % (k, v) for k, v in value.items()) - ret.append("%s = {%s}" % (option, ', '.join(params))) + ret.append("%s = {%s}" % (option, ", ".join(params))) for name, value in options_copy.items(): if value is not None: @@ -1527,7 +1938,6 @@ def _make_option_strings(cls, options_map): class TableMetadataDSE68(TableMetadataV3): - vertex = None """A :class:`.VertexMetadata` instance, if graph enabled""" @@ -1546,18 +1956,21 @@ def as_cql_query(self, formatted=False): ret += self._export_edge_as_cql( self.edge.from_label, self.edge.from_partition_key_columns, - self.edge.from_clustering_columns, "FROM") + self.edge.from_clustering_columns, + "FROM", + ) ret += self._export_edge_as_cql( self.edge.to_label, self.edge.to_partition_key_columns, - self.edge.to_clustering_columns, "TO") + self.edge.to_clustering_columns, + "TO", + ) return ret @staticmethod - def _export_edge_as_cql(label_name, partition_keys, - clustering_columns, keyword): + def _export_edge_as_cql(label_name, partition_keys, clustering_columns, keyword): ret = " %s %s(" % (keyword, protect_name(label_name)) if len(partition_keys) == 1: @@ -1576,6 +1989,7 @@ class TableExtensionInterface(object): """ Defines CQL/DDL for Cassandra table extensions. """ + # limited API for now. Could be expanded as new extension types materialize -- "extend_option_strings", for example @classmethod def after_table_cql(cls, ext_key, ext_blob): @@ -1587,20 +2001,22 @@ def after_table_cql(cls, ext_key, ext_blob): class _RegisteredExtensionType(type): - _extension_registry = {} def __new__(mcs, name, bases, dct): cls = super(_RegisteredExtensionType, mcs).__new__(mcs, name, bases, dct) - if name != 'RegisteredTableExtension': + if name != "RegisteredTableExtension": mcs._extension_registry[cls.name] = cls return cls -class RegisteredTableExtension(TableExtensionInterface, metaclass=_RegisteredExtensionType): +class RegisteredTableExtension( + TableExtensionInterface, metaclass=_RegisteredExtensionType +): """ Extending this class registers it by name (associated by key in the `system_schema.tables.extensions` map). """ + name = None """ Name of the extension (key in the map) @@ -1617,13 +2033,13 @@ def protect_names(names): def protect_value(value): if value is None: - return 'NULL' + return "NULL" if isinstance(value, (int, float, bool)): return str(value).lower() return "'%s'" % value.replace("'", "''") -valid_cql3_word_re = re.compile(r'^[a-z][0-9a-z_]*$') +valid_cql3_word_re = re.compile(r"^[a-z][0-9a-z_]*$") def is_valid_name(name): @@ -1673,7 +2089,9 @@ class ColumnMetadata(object): _cass_type = None - def __init__(self, table_metadata, column_name, cql_type, is_static=False, is_reversed=False): + def __init__( + self, table_metadata, column_name, cql_type, is_static=False, is_reversed=False + ): self.table = table_metadata self.name = column_name self.cql_type = cql_type @@ -1688,6 +2106,7 @@ class IndexMetadata(object): """ A representation of a secondary index on a column. """ + keyspace_name = None """ A string name of the keyspace. """ @@ -1721,7 +2140,8 @@ def as_cql_query(self): protect_name(self.name), protect_name(self.keyspace_name), protect_name(self.table_name), - index_target) + index_target, + ) else: class_name = options.pop("class_name") ret = "CREATE CUSTOM INDEX %s ON %s.%s (%s) USING '%s'" % ( @@ -1729,10 +2149,13 @@ def as_cql_query(self): protect_name(self.keyspace_name), protect_name(self.table_name), index_target, - class_name) + class_name, + ) if options: # PYTHON-1008: `ret` will always be a unicode - opts_cql_encoded = _encoder.cql_encode_all_types(options, as_text_type=True) + opts_cql_encoded = _encoder.cql_encode_all_types( + options, as_text_type=True + ) ret += " WITH OPTIONS = %s" % opts_cql_encoded return ret @@ -1740,7 +2163,7 @@ def export_as_string(self): """ Returns a CQL query string that can be used to recreate this index. """ - return self.as_cql_query() + ';' + return self.as_cql_query() + ";" class TokenMap(object): @@ -1784,16 +2207,24 @@ def rebuild_keyspace(self, keyspace, build_if_absent=False): with self._rebuild_lock: try: current = self.tokens_to_hosts_by_ks.get(keyspace, None) - if (build_if_absent and current is None) or (not build_if_absent and current is not None): + if (build_if_absent and current is None) or ( + not build_if_absent and current is not None + ): ks_meta = self._metadata.keyspaces.get(keyspace) if ks_meta: - replica_map = self.replica_map_for_keyspace(self._metadata.keyspaces[keyspace]) + replica_map = self.replica_map_for_keyspace( + self._metadata.keyspaces[keyspace] + ) self.tokens_to_hosts_by_ks[keyspace] = replica_map except Exception: # should not happen normally, but we don't want to blow up queries because of unexpected meta state # bypass until new map is generated self.tokens_to_hosts_by_ks[keyspace] = {} - log.exception("Failed creating a token map for keyspace '%s' with %s. PLEASE REPORT THIS: https://datastax-oss.atlassian.net/projects/PYTHON", keyspace, self.token_to_host_owner) + log.exception( + "Failed creating a token map for keyspace '%s' with %s. PLEASE REPORT THIS: https://datastax-oss.atlassian.net/projects/PYTHON", + keyspace, + self.token_to_host_owner, + ) def replica_map_for_keyspace(self, ks_metadata): strategy = ks_metadata.replication_strategy @@ -1858,11 +2289,12 @@ def __hash__(self): def __repr__(self): return "<%s: %s>" % (self.__class__.__name__, self.value) + __str__ = __repr__ -MIN_LONG = -(2 ** 63) -MAX_LONG = (2 ** 63) - 1 +MIN_LONG = -(2**63) +MAX_LONG = (2**63) - 1 class NoMurmur3(Exception): @@ -1870,10 +2302,9 @@ class NoMurmur3(Exception): class HashToken(Token): - @classmethod def from_string(cls, token_string): - """ `token_string` should be the string representation from the server. """ + """`token_string` should be the string representation from the server.""" # The hash partitioners just store the deciman value return cls(int(token_string)) @@ -1892,7 +2323,7 @@ def hash_fn(cls, key): raise NoMurmur3() def __init__(self, token): - """ `token` is an int or string representing the token. """ + """`token` is an int or string representing the token.""" super().__init__(int(token)) @@ -1904,7 +2335,7 @@ class MD5Token(HashToken): @classmethod def hash_fn(cls, key): if isinstance(key, str): - key = key.encode('UTF-8') + key = key.encode("UTF-8") return abs(varint_unpack(md5(key).digest())) @@ -1915,10 +2346,10 @@ class BytesToken(Token): @classmethod def from_string(cls, token_string): - """ `token_string` should be the string representation from the server. """ + """`token_string` should be the string representation from the server.""" # unhexlify works fine with unicode input in everythin but pypy3, where it Raises "TypeError: 'str' does not support the buffer interface" if isinstance(token_string, str): - token_string = token_string.encode('ascii') + token_string = token_string.encode("ascii") # The BOP stores a hex string return cls(unhexlify(token_string)) @@ -1939,6 +2370,7 @@ class TriggerMetadata(object): A dict mapping trigger option names to their specific settings for this table. """ + def __init__(self, table_metadata, trigger_name, options=None): self.table = table_metadata self.name = trigger_name @@ -1949,12 +2381,12 @@ def as_cql_query(self): protect_name(self.name), protect_name(self.table.keyspace_name), protect_name(self.table.name), - protect_value(self.options['class']) + protect_value(self.options["class"]), ) return ret def export_as_string(self): - return self.as_cql_query() + ';' + return self.as_cql_query() + ";" class _SchemaParser(object): @@ -1964,7 +2396,9 @@ def __init__(self, connection, timeout, fetch_size, metadata_request_timeout): self.fetch_size = fetch_size self.metadata_request_timeout = metadata_request_timeout - def _handle_results(self, success, result, expected_failures=tuple(), query_msg=None, timeout=None): + def _handle_results( + self, success, result, expected_failures=tuple(), query_msg=None, timeout=None + ): """ Given a bool and a ResultSet (the form returned per result from Connection.wait_for_responses), return a dictionary containing the @@ -1990,13 +2424,21 @@ def _handle_results(self, success, result, expected_failures=tuple(), query_msg= return [] elif success: if result.paging_state and query_msg: + def get_next_pages(): next_result = None while True: - query_msg.paging_state = next_result.paging_state if next_result else result.paging_state - next_success, next_result = self.connection.wait_for_response(query_msg, timeout=timeout, - fail_on_error=False) - if not next_success and isinstance(next_result, expected_failures): + query_msg.paging_state = ( + next_result.paging_state + if next_result + else result.paging_state + ) + next_success, next_result = self.connection.wait_for_response( + query_msg, timeout=timeout, fail_on_error=False + ) + if not next_success and isinstance( + next_result, expected_failures + ): continue elif not next_success: raise next_result @@ -2007,7 +2449,9 @@ def get_next_pages(): yield next_result.parsed_rows result.parsed_rows += itertools.chain(*get_next_pages()) - return dict_factory(result.column_names, result.parsed_rows) if result else [] + return ( + dict_factory(result.column_names, result.parsed_rows) if result else [] + ) else: raise result @@ -2016,11 +2460,20 @@ def _query_build_row(self, query_string, build_func): return result[0] if result else None def _query_build_rows(self, query_string, build_func): - query = QueryMessage(query=maybe_add_timeout_to_query(query_string, self.metadata_request_timeout), - consistency_level=ConsistencyLevel.ONE, fetch_size=self.fetch_size) - responses = self.connection.wait_for_responses((query), timeout=self.timeout, fail_on_error=False) + query = QueryMessage( + query=maybe_add_timeout_to_query( + query_string, self.metadata_request_timeout + ), + consistency_level=ConsistencyLevel.ONE, + fetch_size=self.fetch_size, + ) + responses = self.connection.wait_for_responses( + (query), timeout=self.timeout, fail_on_error=False + ) (success, response) = responses[0] - results = self._handle_results(success, response, expected_failures=(InvalidRequest), query_msg=query) + results = self._handle_results( + success, response, expected_failures=(InvalidRequest), query_msg=query + ) if not results: log.debug("user types table not found") return [build_func(row) for row in results] @@ -2030,6 +2483,7 @@ class SchemaParserV22(_SchemaParser): """ For C* 2.2+ """ + _SELECT_KEYSPACES = "SELECT * FROM system.schema_keyspaces" _SELECT_COLUMN_FAMILIES = "SELECT * FROM system.schema_columnfamilies" _SELECT_COLUMNS = "SELECT * FROM system.schema_columns" @@ -2038,9 +2492,9 @@ class SchemaParserV22(_SchemaParser): _SELECT_FUNCTIONS = "SELECT * FROM system.schema_functions" _SELECT_AGGREGATES = "SELECT * FROM system.schema_aggregates" - _table_name_col = 'columnfamily_name' + _table_name_col = "columnfamily_name" - _function_agg_arument_type_col = 'signature' + _function_agg_arument_type_col = "signature" recognized_table_options = ( "comment", @@ -2048,7 +2502,7 @@ class SchemaParserV22(_SchemaParser): "dclocal_read_repair_chance", # kept to be safe, but see _build_table_options() "local_read_repair_chance", "replicate_on_write", - 'in_memory', + "in_memory", "gc_grace_seconds", "bloom_filter_fp_chance", "caching", @@ -2065,10 +2519,13 @@ class SchemaParserV22(_SchemaParser): "memtable_flush_period_in_ms", "populate_io_cache_on_flush", "compression", - "default_time_to_live") + "default_time_to_live", + ) def __init__(self, connection, timeout, fetch_size, metadata_request_timeout): - super(SchemaParserV22, self).__init__(connection, timeout, fetch_size, metadata_request_timeout) + super(SchemaParserV22, self).__init__( + connection, timeout, fetch_size, metadata_request_timeout + ) self.keyspaces_result = [] self.tables_result = [] self.columns_result = [] @@ -2107,61 +2564,107 @@ def get_all_keyspaces(self): agg = self._build_aggregate(agg_row) keyspace_meta.aggregates[agg.signature] = agg except Exception: - log.exception("Error while parsing metadata for keyspace %s. Metadata model will be incomplete.", keyspace_meta.name) + log.exception( + "Error while parsing metadata for keyspace %s. Metadata model will be incomplete.", + keyspace_meta.name, + ) keyspace_meta._exc_info = sys.exc_info() yield keyspace_meta def get_table(self, keyspaces, keyspace, table): cl = ConsistencyLevel.ONE - where_clause = bind_params(" WHERE keyspace_name = %%s AND %s = %%s" % (self._table_name_col,), (keyspace, table), _encoder) + where_clause = bind_params( + " WHERE keyspace_name = %%s AND %s = %%s" % (self._table_name_col,), + (keyspace, table), + _encoder, + ) cf_query = QueryMessage( - query=maybe_add_timeout_to_query(self._SELECT_COLUMN_FAMILIES + where_clause, self.metadata_request_timeout), + query=maybe_add_timeout_to_query( + self._SELECT_COLUMN_FAMILIES + where_clause, + self.metadata_request_timeout, + ), consistency_level=cl, ) col_query = QueryMessage( - query=maybe_add_timeout_to_query(self._SELECT_COLUMNS + where_clause, self.metadata_request_timeout), + query=maybe_add_timeout_to_query( + self._SELECT_COLUMNS + where_clause, self.metadata_request_timeout + ), consistency_level=cl, ) triggers_query = QueryMessage( - query=maybe_add_timeout_to_query(self._SELECT_TRIGGERS + where_clause, self.metadata_request_timeout), + query=maybe_add_timeout_to_query( + self._SELECT_TRIGGERS + where_clause, self.metadata_request_timeout + ), consistency_level=cl, ) - (cf_success, cf_result), (col_success, col_result), (triggers_success, triggers_result) \ - = self.connection.wait_for_responses(cf_query, col_query, triggers_query, timeout=self.timeout, fail_on_error=False) + ( + (cf_success, cf_result), + (col_success, col_result), + (triggers_success, triggers_result), + ) = self.connection.wait_for_responses( + cf_query, + col_query, + triggers_query, + timeout=self.timeout, + fail_on_error=False, + ) table_result = self._handle_results(cf_success, cf_result) col_result = self._handle_results(col_success, col_result) # the triggers table doesn't exist in C* 1.2 - triggers_result = self._handle_results(triggers_success, triggers_result, - expected_failures=InvalidRequest) + triggers_result = self._handle_results( + triggers_success, triggers_result, expected_failures=InvalidRequest + ) if table_result: - return self._build_table_metadata(table_result[0], col_result, triggers_result) + return self._build_table_metadata( + table_result[0], col_result, triggers_result + ) def get_type(self, keyspaces, keyspace, type): - where_clause = bind_params(" WHERE keyspace_name = %s AND type_name = %s", (keyspace, type), _encoder) - return self._query_build_row(self._SELECT_TYPES + where_clause, self._build_user_type) + where_clause = bind_params( + " WHERE keyspace_name = %s AND type_name = %s", (keyspace, type), _encoder + ) + return self._query_build_row( + self._SELECT_TYPES + where_clause, self._build_user_type + ) def get_types_map(self, keyspaces, keyspace): where_clause = bind_params(" WHERE keyspace_name = %s", (keyspace,), _encoder) - types = self._query_build_rows(self._SELECT_TYPES + where_clause, self._build_user_type) + types = self._query_build_rows( + self._SELECT_TYPES + where_clause, self._build_user_type + ) return dict((t.name, t) for t in types) def get_function(self, keyspaces, keyspace, function): - where_clause = bind_params(" WHERE keyspace_name = %%s AND function_name = %%s AND %s = %%s" % (self._function_agg_arument_type_col,), - (keyspace, function.name, function.argument_types), _encoder) - return self._query_build_row(self._SELECT_FUNCTIONS + where_clause, self._build_function) + where_clause = bind_params( + " WHERE keyspace_name = %%s AND function_name = %%s AND %s = %%s" + % (self._function_agg_arument_type_col,), + (keyspace, function.name, function.argument_types), + _encoder, + ) + return self._query_build_row( + self._SELECT_FUNCTIONS + where_clause, self._build_function + ) def get_aggregate(self, keyspaces, keyspace, aggregate): - where_clause = bind_params(" WHERE keyspace_name = %%s AND aggregate_name = %%s AND %s = %%s" % (self._function_agg_arument_type_col,), - (keyspace, aggregate.name, aggregate.argument_types), _encoder) + where_clause = bind_params( + " WHERE keyspace_name = %%s AND aggregate_name = %%s AND %s = %%s" + % (self._function_agg_arument_type_col,), + (keyspace, aggregate.name, aggregate.argument_types), + _encoder, + ) - return self._query_build_row(self._SELECT_AGGREGATES + where_clause, self._build_aggregate) + return self._query_build_row( + self._SELECT_AGGREGATES + where_clause, self._build_aggregate + ) def get_keyspace(self, keyspaces, keyspace): where_clause = bind_params(" WHERE keyspace_name = %s", (keyspace,), _encoder) - return self._query_build_row(self._SELECT_KEYSPACES + where_clause, self._build_keyspace_metadata) + return self._query_build_row( + self._SELECT_KEYSPACES + where_clause, self._build_keyspace_metadata + ) @classmethod def _build_keyspace_metadata(cls, row): @@ -2169,9 +2672,11 @@ def _build_keyspace_metadata(cls, row): ksm = cls._build_keyspace_metadata_internal(row) except Exception: name = row["keyspace_name"] - ksm = KeyspaceMetadata(name, False, 'UNKNOWN', {}) + ksm = KeyspaceMetadata(name, False, "UNKNOWN", {}) ksm._exc_info = sys.exc_info() # capture exc_info before log because nose (test) logging clears it in certain circumstances - log.exception("Error while parsing metadata for keyspace %s row(%s)", name, row) + log.exception( + "Error while parsing metadata for keyspace %s row(%s)", name, row + ) return ksm @staticmethod @@ -2184,45 +2689,71 @@ def _build_keyspace_metadata_internal(row): @classmethod def _build_user_type(cls, usertype_row): - field_types = list(map(cls._schema_type_to_cql, usertype_row['field_types'])) - return UserType(usertype_row['keyspace_name'], usertype_row['type_name'], - usertype_row['field_names'], field_types) + field_types = list(map(cls._schema_type_to_cql, usertype_row["field_types"])) + return UserType( + usertype_row["keyspace_name"], + usertype_row["type_name"], + usertype_row["field_names"], + field_types, + ) @classmethod def _build_function(cls, function_row): - return_type = cls._schema_type_to_cql(function_row['return_type']) - deterministic = function_row.get('deterministic', False) - monotonic = function_row.get('monotonic', False) - monotonic_on = function_row.get('monotonic_on', ()) - return Function(function_row['keyspace_name'], function_row['function_name'], - function_row[cls._function_agg_arument_type_col], function_row['argument_names'], - return_type, function_row['language'], function_row['body'], - function_row['called_on_null_input'], - deterministic, monotonic, monotonic_on) + return_type = cls._schema_type_to_cql(function_row["return_type"]) + deterministic = function_row.get("deterministic", False) + monotonic = function_row.get("monotonic", False) + monotonic_on = function_row.get("monotonic_on", ()) + return Function( + function_row["keyspace_name"], + function_row["function_name"], + function_row[cls._function_agg_arument_type_col], + function_row["argument_names"], + return_type, + function_row["language"], + function_row["body"], + function_row["called_on_null_input"], + deterministic, + monotonic, + monotonic_on, + ) @classmethod def _build_aggregate(cls, aggregate_row): - cass_state_type = types.lookup_casstype(aggregate_row['state_type']) - initial_condition = aggregate_row['initcond'] + cass_state_type = types.lookup_casstype(aggregate_row["state_type"]) + initial_condition = aggregate_row["initcond"] if initial_condition is not None: - initial_condition = _encoder.cql_encode_all_types(cass_state_type.deserialize(initial_condition, 3)) + initial_condition = _encoder.cql_encode_all_types( + cass_state_type.deserialize(initial_condition, 3) + ) state_type = _cql_from_cass_type(cass_state_type) - return_type = cls._schema_type_to_cql(aggregate_row['return_type']) - return Aggregate(aggregate_row['keyspace_name'], aggregate_row['aggregate_name'], - aggregate_row['signature'], aggregate_row['state_func'], state_type, - aggregate_row['final_func'], initial_condition, return_type, - aggregate_row.get('deterministic', False)) + return_type = cls._schema_type_to_cql(aggregate_row["return_type"]) + return Aggregate( + aggregate_row["keyspace_name"], + aggregate_row["aggregate_name"], + aggregate_row["signature"], + aggregate_row["state_func"], + state_type, + aggregate_row["final_func"], + initial_condition, + return_type, + aggregate_row.get("deterministic", False), + ) def _build_table_metadata(self, row, col_rows=None, trigger_rows=None): keyspace_name = row["keyspace_name"] cfname = row[self._table_name_col] col_rows = col_rows or self.keyspace_table_col_rows[keyspace_name][cfname] - trigger_rows = trigger_rows or self.keyspace_table_trigger_rows[keyspace_name][cfname] + trigger_rows = ( + trigger_rows or self.keyspace_table_trigger_rows[keyspace_name][cfname] + ) if not col_rows: # CASSANDRA-8487 - log.warning("Building table metadata with no column meta for %s.%s", - keyspace_name, cfname) + log.warning( + "Building table metadata with no column meta for %s.%s", + keyspace_name, + cfname, + ) table_meta = TableMetadata(keyspace_name, cfname) @@ -2232,23 +2763,30 @@ def _build_table_metadata(self, row, col_rows=None, trigger_rows=None): is_dct_comparator = issubclass(comparator, types.DynamicCompositeType) is_composite_comparator = issubclass(comparator, types.CompositeType) - column_name_types = comparator.subtypes if is_composite_comparator else (comparator,) + column_name_types = ( + comparator.subtypes if is_composite_comparator else (comparator,) + ) num_column_name_components = len(column_name_types) last_col = column_name_types[-1] column_aliases = row.get("column_aliases", None) - clustering_rows = [r for r in col_rows - if r.get('type', None) == "clustering_key"] + clustering_rows = [ + r for r in col_rows if r.get("type", None) == "clustering_key" + ] if len(clustering_rows) > 1: - clustering_rows = sorted(clustering_rows, key=lambda row: row.get('component_index')) + clustering_rows = sorted( + clustering_rows, key=lambda row: row.get("component_index") + ) if column_aliases is not None: column_aliases = json.loads(column_aliases) - if not column_aliases: # json load failed or column_aliases empty PYTHON-562 - column_aliases = [r.get('column_name') for r in clustering_rows] + if ( + not column_aliases + ): # json load failed or column_aliases empty PYTHON-562 + column_aliases = [r.get("column_name") for r in clustering_rows] if is_composite_comparator: if issubclass(last_col, types.ColumnToCollectionType): @@ -2256,8 +2794,11 @@ def _build_table_metadata(self, row, col_rows=None, trigger_rows=None): is_compact = False has_value = False clustering_size = num_column_name_components - 2 - elif (len(column_aliases) == num_column_name_components - 1 and - issubclass(last_col, types.UTF8Type)): + elif len( + column_aliases + ) == num_column_name_components - 1 and issubclass( + last_col, types.UTF8Type + ): # aliases? is_compact = False has_value = False @@ -2269,8 +2810,8 @@ def _build_table_metadata(self, row, col_rows=None, trigger_rows=None): clustering_size = num_column_name_components # Some thrift tables define names in composite types (see PYTHON-192) - if not column_aliases and hasattr(comparator, 'fieldnames'): - column_aliases = filter(None, comparator.fieldnames) + if not column_aliases and hasattr(comparator, "fieldnames"): + column_aliases = list(filter(None, comparator.fieldnames)) else: is_compact = True if column_aliases or not col_rows or is_dct_comparator: @@ -2281,25 +2822,34 @@ def _build_table_metadata(self, row, col_rows=None, trigger_rows=None): clustering_size = 0 # partition key - partition_rows = [r for r in col_rows - if r.get('type', None) == "partition_key"] + partition_rows = [ + r for r in col_rows if r.get("type", None) == "partition_key" + ] if len(partition_rows) > 1: - partition_rows = sorted(partition_rows, key=lambda row: row.get('component_index')) + partition_rows = sorted( + partition_rows, key=lambda row: row.get("component_index") + ) key_aliases = row.get("key_aliases") if key_aliases is not None: key_aliases = json.loads(key_aliases) if key_aliases else [] else: # In 2.0+, we can use the 'type' column. In 3.0+, we have to use it. - key_aliases = [r.get('column_name') for r in partition_rows] + key_aliases = [r.get("column_name") for r in partition_rows] key_validator = row.get("key_validator") if key_validator is not None: key_type = types.lookup_casstype(key_validator) - key_types = key_type.subtypes if issubclass(key_type, types.CompositeType) else [key_type] + key_types = ( + key_type.subtypes + if issubclass(key_type, types.CompositeType) + else [key_type] + ) else: - key_types = [types.lookup_casstype(r.get('validator')) for r in partition_rows] + key_types = [ + types.lookup_casstype(r.get("validator")) for r in partition_rows + ] for i, col_type in enumerate(key_types): if len(key_aliases) > i: @@ -2309,7 +2859,9 @@ def _build_table_metadata(self, row, col_rows=None, trigger_rows=None): else: column_name = "key%d" % i - col = ColumnMetadata(table_meta, column_name, col_type.cql_parameterized_type()) + col = ColumnMetadata( + table_meta, column_name, col_type.cql_parameterized_type() + ) table_meta.columns[column_name] = col table_meta.partition_key.append(col) @@ -2323,14 +2875,17 @@ def _build_table_metadata(self, row, col_rows=None, trigger_rows=None): data_type = column_name_types[i] cql_type = _cql_from_cass_type(data_type) is_reversed = types.is_reversed_casstype(data_type) - col = ColumnMetadata(table_meta, column_name, cql_type, is_reversed=is_reversed) + col = ColumnMetadata( + table_meta, column_name, cql_type, is_reversed=is_reversed + ) table_meta.columns[column_name] = col table_meta.clustering_key.append(col) # value alias (if present) if has_value: - value_alias_rows = [r for r in col_rows - if r.get('type', None) == "compact_value"] + value_alias_rows = [ + r for r in col_rows if r.get("type", None) == "compact_value" + ] if not key_aliases: # TODO are we checking the right thing here? value_alias = "value" @@ -2338,14 +2893,16 @@ def _build_table_metadata(self, row, col_rows=None, trigger_rows=None): value_alias = row.get("value_alias", None) if value_alias is None and value_alias_rows: # CASSANDRA-8487 # In 2.0+, we can use the 'type' column. In 3.0+, we have to use it. - value_alias = value_alias_rows[0].get('column_name') + value_alias = value_alias_rows[0].get("column_name") default_validator = row.get("default_validator") if default_validator: validator = types.lookup_casstype(default_validator) else: if value_alias_rows: # CASSANDRA-8487 - validator = types.lookup_casstype(value_alias_rows[0].get('validator')) + validator = types.lookup_casstype( + value_alias_rows[0].get("validator") + ) cql_type = _cql_from_cass_type(validator) col = ColumnMetadata(table_meta, value_alias, cql_type) @@ -2369,13 +2926,21 @@ def _build_table_metadata(self, row, col_rows=None, trigger_rows=None): table_meta.is_compact_storage = is_compact except Exception: table_meta._exc_info = sys.exc_info() - log.exception("Error while parsing metadata for table %s.%s row(%s) columns(%s)", keyspace_name, cfname, row, col_rows) + log.exception( + "Error while parsing metadata for table %s.%s row(%s) columns(%s)", + keyspace_name, + cfname, + row, + col_rows, + ) return table_meta def _build_table_options(self, row): - """ Setup the mostly-non-schema table options, like caching settings """ - options = dict((o, row.get(o)) for o in self.recognized_table_options if o in row) + """Setup the mostly-non-schema table options, like caching settings""" + options = dict( + (o, row.get(o)) for o in self.recognized_table_options if o in row + ) # the option name when creating tables is "dclocal_read_repair_chance", # but the column name in system.schema_columnfamilies is @@ -2396,7 +2961,9 @@ def _build_column_metadata(cls, table_metadata, row): cql_type = _cql_from_cass_type(data_type) is_static = row.get("type", None) == "static" is_reversed = types.is_reversed_casstype(data_type) - column_meta = ColumnMetadata(table_metadata, name, cql_type, is_static, is_reversed) + column_meta = ColumnMetadata( + table_metadata, name, cql_type, is_static, is_reversed + ) column_meta._cass_type = data_type return column_meta @@ -2413,7 +2980,7 @@ def _build_index_metadata(column_metadata, row): target = protect_name(column_metadata.name) if kind != "CUSTOM": if "index_keys" in options: - target = 'keys(%s)' % (target,) + target = "keys(%s)" % (target,) elif "index_values" in options: # don't use any "function" for collection values pass @@ -2423,12 +2990,21 @@ def _build_index_metadata(column_metadata, row): # there is no special index option for full-collection # indexes. data_type = column_metadata._cass_type - collection_types = ('map', 'set', 'list') - if data_type.typename == "frozen" and data_type.subtypes[0].typename in collection_types: + collection_types = ("map", "set", "list") + if ( + data_type.typename == "frozen" + and data_type.subtypes[0].typename in collection_types + ): # no index option for full-collection index - target = 'full(%s)' % (target,) - options['target'] = target - return IndexMetadata(column_metadata.table.keyspace_name, column_metadata.table.name, index_name, kind, options) + target = "full(%s)" % (target,) + options["target"] = target + return IndexMetadata( + column_metadata.table.keyspace_name, + column_metadata.table.name, + index_name, + kind, + options, + ) @staticmethod def _build_trigger_metadata(table_metadata, row): @@ -2441,44 +3017,59 @@ def _query_all(self): cl = ConsistencyLevel.ONE queries = [ QueryMessage( - query=maybe_add_timeout_to_query(self._SELECT_KEYSPACES, self.metadata_request_timeout), + query=maybe_add_timeout_to_query( + self._SELECT_KEYSPACES, self.metadata_request_timeout + ), consistency_level=cl, ), QueryMessage( - query=maybe_add_timeout_to_query(self._SELECT_COLUMN_FAMILIES, self.metadata_request_timeout), + query=maybe_add_timeout_to_query( + self._SELECT_COLUMN_FAMILIES, self.metadata_request_timeout + ), consistency_level=cl, ), QueryMessage( - query=maybe_add_timeout_to_query(self._SELECT_COLUMNS, self.metadata_request_timeout), + query=maybe_add_timeout_to_query( + self._SELECT_COLUMNS, self.metadata_request_timeout + ), consistency_level=cl, ), QueryMessage( - query=maybe_add_timeout_to_query(self._SELECT_TYPES, self.metadata_request_timeout), + query=maybe_add_timeout_to_query( + self._SELECT_TYPES, self.metadata_request_timeout + ), consistency_level=cl, ), QueryMessage( - query=maybe_add_timeout_to_query(self._SELECT_FUNCTIONS, self.metadata_request_timeout), + query=maybe_add_timeout_to_query( + self._SELECT_FUNCTIONS, self.metadata_request_timeout + ), consistency_level=cl, ), QueryMessage( - query=maybe_add_timeout_to_query(self._SELECT_AGGREGATES, self.metadata_request_timeout), + query=maybe_add_timeout_to_query( + self._SELECT_AGGREGATES, self.metadata_request_timeout + ), consistency_level=cl, ), QueryMessage( - query=maybe_add_timeout_to_query(self._SELECT_TRIGGERS, self.metadata_request_timeout), + query=maybe_add_timeout_to_query( + self._SELECT_TRIGGERS, self.metadata_request_timeout + ), consistency_level=cl, - ) + ), ] - ((ks_success, ks_result), - (table_success, table_result), - (col_success, col_result), - (types_success, types_result), - (functions_success, functions_result), - (aggregates_success, aggregates_result), - (triggers_success, triggers_result)) = ( - self.connection.wait_for_responses(*queries, timeout=self.timeout, - fail_on_error=False) + ( + (ks_success, ks_result), + (table_success, table_result), + (col_success, col_result), + (types_success, types_result), + (functions_success, functions_result), + (aggregates_success, aggregates_result), + (triggers_success, triggers_result), + ) = self.connection.wait_for_responses( + *queries, timeout=self.timeout, fail_on_error=False ) self.keyspaces_result = self._handle_results(ks_success, ks_result) @@ -2487,19 +3078,25 @@ def _query_all(self): # if we're connected to Cassandra < 2.0, the triggers table will not exist if triggers_success: - self.triggers_result = dict_factory(triggers_result.column_names, triggers_result.parsed_rows) + self.triggers_result = dict_factory( + triggers_result.column_names, triggers_result.parsed_rows + ) else: if isinstance(triggers_result, InvalidRequest): log.debug("triggers table not found") elif isinstance(triggers_result, Unauthorized): - log.warning("this version of Cassandra does not allow access to schema_triggers metadata with authorization enabled (CASSANDRA-7967); " - "The driver will operate normally, but will not reflect triggers in the local metadata model, or schema strings.") + log.warning( + "this version of Cassandra does not allow access to schema_triggers metadata with authorization enabled (CASSANDRA-7967); " + "The driver will operate normally, but will not reflect triggers in the local metadata model, or schema strings." + ) else: raise triggers_result # if we're connected to Cassandra < 2.1, the usertypes table will not exist if types_success: - self.types_result = dict_factory(types_result.column_names, types_result.parsed_rows) + self.types_result = dict_factory( + types_result.column_names, types_result.parsed_rows + ) else: if isinstance(types_result, InvalidRequest): log.debug("user types table not found") @@ -2509,7 +3106,9 @@ def _query_all(self): # functions were introduced in Cassandra 2.2 if functions_success: - self.functions_result = dict_factory(functions_result.column_names, functions_result.parsed_rows) + self.functions_result = dict_factory( + functions_result.column_names, functions_result.parsed_rows + ) else: if isinstance(functions_result, InvalidRequest): log.debug("user functions table not found") @@ -2518,7 +3117,9 @@ def _query_all(self): # aggregates were introduced in Cassandra 2.2 if aggregates_success: - self.aggregates_result = dict_factory(aggregates_result.column_names, aggregates_result.parsed_rows) + self.aggregates_result = dict_factory( + aggregates_result.column_names, aggregates_result.parsed_rows + ) else: if isinstance(aggregates_result, InvalidRequest): log.debug("user aggregates table not found") @@ -2567,6 +3168,7 @@ class SchemaParserV3(SchemaParserV22): """ For C* 3.0+ """ + _SELECT_KEYSPACES = "SELECT * FROM system_schema.keyspaces" _SELECT_TABLES = "SELECT * FROM system_schema.tables" _SELECT_COLUMNS = "SELECT * FROM system_schema.columns" @@ -2577,32 +3179,35 @@ class SchemaParserV3(SchemaParserV22): _SELECT_AGGREGATES = "SELECT * FROM system_schema.aggregates" _SELECT_VIEWS = "SELECT * FROM system_schema.views" - _table_name_col = 'table_name' + _table_name_col = "table_name" - _function_agg_arument_type_col = 'argument_types' + _function_agg_arument_type_col = "argument_types" _table_metadata_class = TableMetadataV3 recognized_table_options = ( - 'bloom_filter_fp_chance', - 'caching', - 'cdc', - 'comment', - 'compaction', - 'compression', - 'crc_check_chance', - 'dclocal_read_repair_chance', - 'default_time_to_live', - 'in_memory', - 'gc_grace_seconds', - 'max_index_interval', - 'memtable_flush_period_in_ms', - 'min_index_interval', - 'read_repair_chance', - 'speculative_retry') + "bloom_filter_fp_chance", + "caching", + "cdc", + "comment", + "compaction", + "compression", + "crc_check_chance", + "dclocal_read_repair_chance", + "default_time_to_live", + "in_memory", + "gc_grace_seconds", + "max_index_interval", + "memtable_flush_period_in_ms", + "min_index_interval", + "read_repair_chance", + "speculative_retry", + ) def __init__(self, connection, timeout, fetch_size, metadata_request_timeout): - super(SchemaParserV3, self).__init__(connection, timeout, fetch_size, metadata_request_timeout) + super(SchemaParserV3, self).__init__( + connection, timeout, fetch_size, metadata_request_timeout + ) self.indexes_result = [] self.keyspace_table_index_rows = defaultdict(lambda: defaultdict(list)) self.keyspace_view_rows = defaultdict(list) @@ -2617,40 +3222,82 @@ def get_all_keyspaces(self): def get_table(self, keyspaces, keyspace, table): cl = ConsistencyLevel.ONE fetch_size = self.fetch_size - where_clause = bind_params(" WHERE keyspace_name = %%s AND %s = %%s" % (self._table_name_col), (keyspace, table), _encoder) + where_clause = bind_params( + " WHERE keyspace_name = %%s AND %s = %%s" % (self._table_name_col), + (keyspace, table), + _encoder, + ) cf_query = QueryMessage( - query=maybe_add_timeout_to_query(self._SELECT_TABLES + where_clause, self.metadata_request_timeout), - consistency_level=cl, fetch_size=fetch_size) + query=maybe_add_timeout_to_query( + self._SELECT_TABLES + where_clause, self.metadata_request_timeout + ), + consistency_level=cl, + fetch_size=fetch_size, + ) col_query = QueryMessage( - query=maybe_add_timeout_to_query(self._SELECT_COLUMNS + where_clause, self.metadata_request_timeout), - consistency_level=cl, fetch_size=fetch_size) + query=maybe_add_timeout_to_query( + self._SELECT_COLUMNS + where_clause, self.metadata_request_timeout + ), + consistency_level=cl, + fetch_size=fetch_size, + ) indexes_query = QueryMessage( - query=maybe_add_timeout_to_query(self._SELECT_INDEXES + where_clause, self.metadata_request_timeout), - consistency_level=cl, fetch_size=fetch_size) + query=maybe_add_timeout_to_query( + self._SELECT_INDEXES + where_clause, self.metadata_request_timeout + ), + consistency_level=cl, + fetch_size=fetch_size, + ) triggers_query = QueryMessage( - query=maybe_add_timeout_to_query(self._SELECT_TRIGGERS + where_clause, self.metadata_request_timeout), - consistency_level=cl, fetch_size=fetch_size) + query=maybe_add_timeout_to_query( + self._SELECT_TRIGGERS + where_clause, self.metadata_request_timeout + ), + consistency_level=cl, + fetch_size=fetch_size, + ) # in protocol v4 we don't know if this event is a view or a table, so we look for both - where_clause = bind_params(" WHERE keyspace_name = %s AND view_name = %s", (keyspace, table), _encoder) + where_clause = bind_params( + " WHERE keyspace_name = %s AND view_name = %s", (keyspace, table), _encoder + ) view_query = QueryMessage( - query=maybe_add_timeout_to_query(self._SELECT_VIEWS + where_clause, self.metadata_request_timeout), - consistency_level=cl, fetch_size=fetch_size) - ((cf_success, cf_result), (col_success, col_result), - (indexes_sucess, indexes_result), (triggers_success, triggers_result), - (view_success, view_result)) = ( - self.connection.wait_for_responses( - cf_query, col_query, indexes_query, triggers_query, - view_query, timeout=self.timeout, fail_on_error=False) + query=maybe_add_timeout_to_query( + self._SELECT_VIEWS + where_clause, self.metadata_request_timeout + ), + consistency_level=cl, + fetch_size=fetch_size, + ) + ( + (cf_success, cf_result), + (col_success, col_result), + (indexes_sucess, indexes_result), + (triggers_success, triggers_result), + (view_success, view_result), + ) = self.connection.wait_for_responses( + cf_query, + col_query, + indexes_query, + triggers_query, + view_query, + timeout=self.timeout, + fail_on_error=False, ) table_result = self._handle_results(cf_success, cf_result, query_msg=cf_query) col_result = self._handle_results(col_success, col_result, query_msg=col_query) if table_result: - indexes_result = self._handle_results(indexes_sucess, indexes_result, query_msg=indexes_query) - triggers_result = self._handle_results(triggers_success, triggers_result, query_msg=triggers_query) - return self._build_table_metadata(table_result[0], col_result, triggers_result, indexes_result) + indexes_result = self._handle_results( + indexes_sucess, indexes_result, query_msg=indexes_query + ) + triggers_result = self._handle_results( + triggers_success, triggers_result, query_msg=triggers_query + ) + return self._build_table_metadata( + table_result[0], col_result, triggers_result, indexes_result + ) - view_result = self._handle_results(view_success, view_result, query_msg=view_query) + view_result = self._handle_results( + view_success, view_result, query_msg=view_query + ) if view_result: return self._build_view_metadata(view_result[0], col_result) @@ -2664,27 +3311,46 @@ def _build_keyspace_metadata_internal(row): @staticmethod def _build_aggregate(aggregate_row): - return Aggregate(aggregate_row['keyspace_name'], aggregate_row['aggregate_name'], - aggregate_row['argument_types'], aggregate_row['state_func'], aggregate_row['state_type'], - aggregate_row['final_func'], aggregate_row['initcond'], aggregate_row['return_type'], - aggregate_row.get('deterministic', False)) + return Aggregate( + aggregate_row["keyspace_name"], + aggregate_row["aggregate_name"], + aggregate_row["argument_types"], + aggregate_row["state_func"], + aggregate_row["state_type"], + aggregate_row["final_func"], + aggregate_row["initcond"], + aggregate_row["return_type"], + aggregate_row.get("deterministic", False), + ) - def _build_table_metadata(self, row, col_rows=None, trigger_rows=None, index_rows=None, virtual=False): + def _build_table_metadata( + self, row, col_rows=None, trigger_rows=None, index_rows=None, virtual=False + ): keyspace_name = row["keyspace_name"] table_name = row[self._table_name_col] col_rows = col_rows or self.keyspace_table_col_rows[keyspace_name][table_name] - trigger_rows = trigger_rows or self.keyspace_table_trigger_rows[keyspace_name][table_name] - index_rows = index_rows or self.keyspace_table_index_rows[keyspace_name][table_name] + trigger_rows = ( + trigger_rows or self.keyspace_table_trigger_rows[keyspace_name][table_name] + ) + index_rows = ( + index_rows or self.keyspace_table_index_rows[keyspace_name][table_name] + ) - table_meta = self._table_metadata_class(keyspace_name, table_name, virtual=virtual) + table_meta = self._table_metadata_class( + keyspace_name, table_name, virtual=virtual + ) try: table_meta.options = self._build_table_options(row) - flags = row.get('flags', set()) + flags = row.get("flags", set()) if flags: - is_dense = 'dense' in flags - compact_static = not is_dense and 'super' not in flags and 'compound' not in flags - table_meta.is_compact_storage = is_dense or 'super' in flags or 'compound' not in flags + is_dense = "dense" in flags + compact_static = ( + not is_dense and "super" not in flags and "compound" not in flags + ) + table_meta.is_compact_storage = ( + is_dense or "super" in flags or "compound" not in flags + ) elif virtual: compact_static = False table_meta.is_compact_storage = False @@ -2694,7 +3360,9 @@ def _build_table_metadata(self, row, col_rows=None, trigger_rows=None, index_row table_meta.is_compact_storage = True is_dense = False - self._build_table_columns(table_meta, col_rows, compact_static, is_dense, virtual) + self._build_table_columns( + table_meta, col_rows, compact_static, is_dense, virtual + ) for trigger_row in trigger_rows: trigger_meta = self._build_trigger_metadata(table_meta, trigger_row) @@ -2705,43 +3373,56 @@ def _build_table_metadata(self, row, col_rows=None, trigger_rows=None, index_row if index_meta: table_meta.indexes[index_meta.name] = index_meta - table_meta.extensions = row.get('extensions', {}) + table_meta.extensions = row.get("extensions", {}) except Exception: table_meta._exc_info = sys.exc_info() - log.exception("Error while parsing metadata for table %s.%s row(%s) columns(%s)", keyspace_name, table_name, row, col_rows) + log.exception( + "Error while parsing metadata for table %s.%s row(%s) columns(%s)", + keyspace_name, + table_name, + row, + col_rows, + ) return table_meta def _build_table_options(self, row): - """ Setup the mostly-non-schema table options, like caching settings """ + """Setup the mostly-non-schema table options, like caching settings""" return dict((o, row.get(o)) for o in self.recognized_table_options if o in row) - def _build_table_columns(self, meta, col_rows, compact_static=False, is_dense=False, virtual=False): + def _build_table_columns( + self, meta, col_rows, compact_static=False, is_dense=False, virtual=False + ): # partition key - partition_rows = [r for r in col_rows - if r.get('kind', None) == "partition_key"] + partition_rows = [r for r in col_rows if r.get("kind", None) == "partition_key"] if len(partition_rows) > 1: - partition_rows = sorted(partition_rows, key=lambda row: row.get('position')) + partition_rows = sorted(partition_rows, key=lambda row: row.get("position")) for r in partition_rows: # we have to add meta here (and not in the later loop) because TableMetadata.columns is an # OrderedDict, and it assumes keys are inserted first, in order, when exporting CQL column_meta = self._build_column_metadata(meta, r) meta.columns[column_meta.name] = column_meta - meta.partition_key.append(meta.columns[r.get('column_name')]) + meta.partition_key.append(meta.columns[r.get("column_name")]) # clustering key if not compact_static: - clustering_rows = [r for r in col_rows - if r.get('kind', None) == "clustering"] + clustering_rows = [ + r for r in col_rows if r.get("kind", None) == "clustering" + ] if len(clustering_rows) > 1: - clustering_rows = sorted(clustering_rows, key=lambda row: row.get('position')) + clustering_rows = sorted( + clustering_rows, key=lambda row: row.get("position") + ) for r in clustering_rows: column_meta = self._build_column_metadata(meta, r) meta.columns[column_meta.name] = column_meta - meta.clustering_key.append(meta.columns[r.get('column_name')]) + meta.clustering_key.append(meta.columns[r.get("column_name")]) - for col_row in (r for r in col_rows - if r.get('kind', None) not in ('partition_key', 'clustering_key')): + for col_row in ( + r + for r in col_rows + if r.get("kind", None) not in ("partition_key", "clustering_key") + ): column_meta = self._build_column_metadata(meta, col_row) if is_dense and column_meta.cql_type == types.cql_empty_type: continue @@ -2760,10 +3441,16 @@ def _build_view_metadata(self, row, col_rows=None): include_all_columns = row["include_all_columns"] where_clause = row["where_clause"] col_rows = col_rows or self.keyspace_table_col_rows[keyspace_name][view_name] - view_meta = MaterializedViewMetadata(keyspace_name, view_name, base_table_name, - include_all_columns, where_clause, self._build_table_options(row)) + view_meta = MaterializedViewMetadata( + keyspace_name, + view_name, + base_table_name, + include_all_columns, + where_clause, + self._build_table_options(row), + ) self._build_table_columns(view_meta, col_rows) - view_meta.extensions = row.get('extensions', {}) + view_meta.extensions = row.get("extensions", {}) return view_meta @@ -2773,7 +3460,9 @@ def _build_column_metadata(table_metadata, row): cql_type = row["type"] is_static = row.get("kind", None) == "static" is_reversed = row["clustering_order"].upper() == "DESC" - column_meta = ColumnMetadata(table_metadata, name, cql_type, is_static, is_reversed) + column_meta = ColumnMetadata( + table_metadata, name, cql_type, is_static, is_reversed + ) return column_meta @staticmethod @@ -2782,7 +3471,13 @@ def _build_index_metadata(table_metadata, row): kind = row.get("kind") if index_name or kind: index_options = row.get("options") - return IndexMetadata(table_metadata.keyspace_name, table_metadata.name, index_name, kind, index_options) + return IndexMetadata( + table_metadata.keyspace_name, + table_metadata.name, + index_name, + kind, + index_options, + ) else: return None @@ -2797,47 +3492,112 @@ def _query_all(self): cl = ConsistencyLevel.ONE fetch_size = self.fetch_size queries = [ - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_KEYSPACES, self.metadata_request_timeout), - fetch_size=fetch_size, consistency_level=cl), - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_TABLES, self.metadata_request_timeout), - fetch_size=fetch_size, consistency_level=cl), - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_COLUMNS, self.metadata_request_timeout), - fetch_size=fetch_size, consistency_level=cl), - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_TYPES, self.metadata_request_timeout), - fetch_size=fetch_size, consistency_level=cl), - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_FUNCTIONS, self.metadata_request_timeout), - fetch_size=fetch_size, consistency_level=cl), - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_AGGREGATES, self.metadata_request_timeout), - fetch_size=fetch_size, consistency_level=cl), - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_TRIGGERS, self.metadata_request_timeout), - fetch_size=fetch_size, consistency_level=cl), - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_INDEXES, self.metadata_request_timeout), - fetch_size=fetch_size, consistency_level=cl), - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_VIEWS, self.metadata_request_timeout), - fetch_size=fetch_size, consistency_level=cl), + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_KEYSPACES, self.metadata_request_timeout + ), + fetch_size=fetch_size, + consistency_level=cl, + ), + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_TABLES, self.metadata_request_timeout + ), + fetch_size=fetch_size, + consistency_level=cl, + ), + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_COLUMNS, self.metadata_request_timeout + ), + fetch_size=fetch_size, + consistency_level=cl, + ), + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_TYPES, self.metadata_request_timeout + ), + fetch_size=fetch_size, + consistency_level=cl, + ), + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_FUNCTIONS, self.metadata_request_timeout + ), + fetch_size=fetch_size, + consistency_level=cl, + ), + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_AGGREGATES, self.metadata_request_timeout + ), + fetch_size=fetch_size, + consistency_level=cl, + ), + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_TRIGGERS, self.metadata_request_timeout + ), + fetch_size=fetch_size, + consistency_level=cl, + ), + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_INDEXES, self.metadata_request_timeout + ), + fetch_size=fetch_size, + consistency_level=cl, + ), + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_VIEWS, self.metadata_request_timeout + ), + fetch_size=fetch_size, + consistency_level=cl, + ), ] - ((ks_success, ks_result), - (table_success, table_result), - (col_success, col_result), - (types_success, types_result), - (functions_success, functions_result), - (aggregates_success, aggregates_result), - (triggers_success, triggers_result), - (indexes_success, indexes_result), - (views_success, views_result)) = self.connection.wait_for_responses( - *queries, timeout=self.timeout, fail_on_error=False - ) - - self.keyspaces_result = self._handle_results(ks_success, ks_result, query_msg=queries[0]) - self.tables_result = self._handle_results(table_success, table_result, query_msg=queries[1]) - self.columns_result = self._handle_results(col_success, col_result, query_msg=queries[2]) - self.triggers_result = self._handle_results(triggers_success, triggers_result, query_msg=queries[6]) - self.types_result = self._handle_results(types_success, types_result, query_msg=queries[3]) - self.functions_result = self._handle_results(functions_success, functions_result, query_msg=queries[4]) - self.aggregates_result = self._handle_results(aggregates_success, aggregates_result, query_msg=queries[5]) - self.indexes_result = self._handle_results(indexes_success, indexes_result, query_msg=queries[7]) - self.views_result = self._handle_results(views_success, views_result, query_msg=queries[8]) + ( + (ks_success, ks_result), + (table_success, table_result), + (col_success, col_result), + (types_success, types_result), + (functions_success, functions_result), + (aggregates_success, aggregates_result), + (triggers_success, triggers_result), + (indexes_success, indexes_result), + (views_success, views_result), + ) = self.connection.wait_for_responses( + *queries, timeout=self.timeout, fail_on_error=False + ) + + self.keyspaces_result = self._handle_results( + ks_success, ks_result, query_msg=queries[0] + ) + self.tables_result = self._handle_results( + table_success, table_result, query_msg=queries[1] + ) + self.columns_result = self._handle_results( + col_success, col_result, query_msg=queries[2] + ) + self.triggers_result = self._handle_results( + triggers_success, triggers_result, query_msg=queries[6] + ) + self.types_result = self._handle_results( + types_success, types_result, query_msg=queries[3] + ) + self.functions_result = self._handle_results( + functions_success, functions_result, query_msg=queries[4] + ) + self.aggregates_result = self._handle_results( + aggregates_success, aggregates_result, query_msg=queries[5] + ) + self.indexes_result = self._handle_results( + indexes_success, indexes_result, query_msg=queries[7] + ) + self.views_result = self._handle_results( + views_success, views_result, query_msg=queries[8] + ) self._aggregate_results() @@ -2863,35 +3623,37 @@ class SchemaParserDSE60(SchemaParserV3): """ For DSE 6.0+ """ - recognized_table_options = (SchemaParserV3.recognized_table_options + - ("nodesync",)) + recognized_table_options = SchemaParserV3.recognized_table_options + ("nodesync",) -class SchemaParserV4(SchemaParserV3): +class SchemaParserV4(SchemaParserV3): recognized_table_options = ( - 'additional_write_policy', - 'bloom_filter_fp_chance', - 'caching', - 'cdc', - 'comment', - 'compaction', - 'compression', - 'crc_check_chance', - 'default_time_to_live', - 'gc_grace_seconds', - 'max_index_interval', - 'memtable_flush_period_in_ms', - 'min_index_interval', - 'read_repair', - 'speculative_retry') - - _SELECT_VIRTUAL_KEYSPACES = 'SELECT * from system_virtual_schema.keyspaces' - _SELECT_VIRTUAL_TABLES = 'SELECT * from system_virtual_schema.tables' - _SELECT_VIRTUAL_COLUMNS = 'SELECT * from system_virtual_schema.columns' + "additional_write_policy", + "bloom_filter_fp_chance", + "caching", + "cdc", + "comment", + "compaction", + "compression", + "crc_check_chance", + "default_time_to_live", + "gc_grace_seconds", + "max_index_interval", + "memtable_flush_period_in_ms", + "min_index_interval", + "read_repair", + "speculative_retry", + ) + + _SELECT_VIRTUAL_KEYSPACES = "SELECT * from system_virtual_schema.keyspaces" + _SELECT_VIRTUAL_TABLES = "SELECT * from system_virtual_schema.tables" + _SELECT_VIRTUAL_COLUMNS = "SELECT * from system_virtual_schema.columns" def __init__(self, connection, timeout, fetch_size, metadata_request_timeout): - super(SchemaParserV4, self).__init__(connection, timeout, fetch_size, metadata_request_timeout) + super(SchemaParserV4, self).__init__( + connection, timeout, fetch_size, metadata_request_timeout + ) self.virtual_keyspaces_rows = defaultdict(list) self.virtual_tables_rows = defaultdict(list) self.virtual_columns_rows = defaultdict(lambda: defaultdict(list)) @@ -2903,35 +3665,96 @@ def _query_all(self): fetch_size = self.fetch_size queries = [ # copied from V3 - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_KEYSPACES, self.metadata_request_timeout), - fetch_size=fetch_size, consistency_level=cl), - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_TABLES, self.metadata_request_timeout), - fetch_size=fetch_size, consistency_level=cl), - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_COLUMNS, self.metadata_request_timeout), - fetch_size=fetch_size, consistency_level=cl), - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_TYPES, self.metadata_request_timeout), - fetch_size=fetch_size, consistency_level=cl), - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_FUNCTIONS, self.metadata_request_timeout), - fetch_size=fetch_size, consistency_level=cl), - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_AGGREGATES, self.metadata_request_timeout), - fetch_size=fetch_size, consistency_level=cl), - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_TRIGGERS, self.metadata_request_timeout), - fetch_size=fetch_size, consistency_level=cl), - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_INDEXES, self.metadata_request_timeout), - fetch_size=fetch_size, consistency_level=cl), - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_VIEWS, self.metadata_request_timeout), - fetch_size=fetch_size, consistency_level=cl), + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_KEYSPACES, self.metadata_request_timeout + ), + fetch_size=fetch_size, + consistency_level=cl, + ), + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_TABLES, self.metadata_request_timeout + ), + fetch_size=fetch_size, + consistency_level=cl, + ), + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_COLUMNS, self.metadata_request_timeout + ), + fetch_size=fetch_size, + consistency_level=cl, + ), + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_TYPES, self.metadata_request_timeout + ), + fetch_size=fetch_size, + consistency_level=cl, + ), + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_FUNCTIONS, self.metadata_request_timeout + ), + fetch_size=fetch_size, + consistency_level=cl, + ), + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_AGGREGATES, self.metadata_request_timeout + ), + fetch_size=fetch_size, + consistency_level=cl, + ), + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_TRIGGERS, self.metadata_request_timeout + ), + fetch_size=fetch_size, + consistency_level=cl, + ), + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_INDEXES, self.metadata_request_timeout + ), + fetch_size=fetch_size, + consistency_level=cl, + ), + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_VIEWS, self.metadata_request_timeout + ), + fetch_size=fetch_size, + consistency_level=cl, + ), # V4-only queries - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_VIRTUAL_KEYSPACES, self.metadata_request_timeout), - fetch_size=fetch_size, consistency_level=cl), - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_VIRTUAL_TABLES, self.metadata_request_timeout), - fetch_size=fetch_size, consistency_level=cl), - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_VIRTUAL_COLUMNS, self.metadata_request_timeout), - fetch_size=fetch_size, consistency_level=cl), + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_VIRTUAL_KEYSPACES, self.metadata_request_timeout + ), + fetch_size=fetch_size, + consistency_level=cl, + ), + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_VIRTUAL_TABLES, self.metadata_request_timeout + ), + fetch_size=fetch_size, + consistency_level=cl, + ), + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_VIRTUAL_COLUMNS, self.metadata_request_timeout + ), + fetch_size=fetch_size, + consistency_level=cl, + ), ] responses = self.connection.wait_for_responses( - *queries, timeout=self.timeout, fail_on_error=False) + *queries, timeout=self.timeout, fail_on_error=False + ) ( # copied from V3 (ks_success, ks_result), @@ -2946,33 +3769,57 @@ def _query_all(self): # V4-only responses (virtual_ks_success, virtual_ks_result), (virtual_table_success, virtual_table_result), - (virtual_column_success, virtual_column_result) + (virtual_column_success, virtual_column_result), ) = responses # copied from V3 - self.keyspaces_result = self._handle_results(ks_success, ks_result, query_msg=queries[0]) - self.tables_result = self._handle_results(table_success, table_result, query_msg=queries[1]) - self.columns_result = self._handle_results(col_success, col_result, query_msg=queries[2]) - self.triggers_result = self._handle_results(triggers_success, triggers_result, query_msg=queries[6]) - self.types_result = self._handle_results(types_success, types_result, query_msg=queries[3]) - self.functions_result = self._handle_results(functions_success, functions_result, query_msg=queries[4]) - self.aggregates_result = self._handle_results(aggregates_success, aggregates_result, query_msg=queries[5]) - self.indexes_result = self._handle_results(indexes_success, indexes_result, query_msg=queries[7]) - self.views_result = self._handle_results(views_success, views_result, query_msg=queries[8]) + self.keyspaces_result = self._handle_results( + ks_success, ks_result, query_msg=queries[0] + ) + self.tables_result = self._handle_results( + table_success, table_result, query_msg=queries[1] + ) + self.columns_result = self._handle_results( + col_success, col_result, query_msg=queries[2] + ) + self.triggers_result = self._handle_results( + triggers_success, triggers_result, query_msg=queries[6] + ) + self.types_result = self._handle_results( + types_success, types_result, query_msg=queries[3] + ) + self.functions_result = self._handle_results( + functions_success, functions_result, query_msg=queries[4] + ) + self.aggregates_result = self._handle_results( + aggregates_success, aggregates_result, query_msg=queries[5] + ) + self.indexes_result = self._handle_results( + indexes_success, indexes_result, query_msg=queries[7] + ) + self.views_result = self._handle_results( + views_success, views_result, query_msg=queries[8] + ) # V4-only results # These tables don't exist in some DSE versions reporting 4.X so we can # ignore them if we got an error self.virtual_keyspaces_result = self._handle_results( - virtual_ks_success, virtual_ks_result, - expected_failures=(InvalidRequest,), query_msg=queries[9] + virtual_ks_success, + virtual_ks_result, + expected_failures=(InvalidRequest,), + query_msg=queries[9], ) self.virtual_tables_result = self._handle_results( - virtual_table_success, virtual_table_result, - expected_failures=(InvalidRequest,), query_msg=queries[10] + virtual_table_success, + virtual_table_result, + expected_failures=(InvalidRequest,), + query_msg=queries[10], ) self.virtual_columns_result = self._handle_results( - virtual_column_success, virtual_column_result, - expected_failures=(InvalidRequest,), query_msg=queries[11] + virtual_column_success, + virtual_column_result, + expected_failures=(InvalidRequest,), + query_msg=queries[11], ) self._aggregate_results() @@ -2986,7 +3833,7 @@ def _aggregate_results(self): m = self.virtual_columns_rows for row in self.virtual_columns_result: - ks_name = row['keyspace_name'] + ks_name = row["keyspace_name"] tab_name = row[self._table_name_col] m[ks_name][tab_name].append(row) @@ -2995,7 +3842,7 @@ def get_all_keyspaces(self): yield x for row in self.virtual_keyspaces_result: - ks_name = row['keyspace_name'] + ks_name = row["keyspace_name"] keyspace_meta = self._build_keyspace_metadata(row) keyspace_meta.virtual = True @@ -3004,9 +3851,9 @@ def get_all_keyspaces(self): col_rows = self.virtual_columns_rows[ks_name][table_name] keyspace_meta._add_table_metadata( - self._build_table_metadata(table_row, - col_rows=col_rows, - virtual=True) + self._build_table_metadata( + table_row, col_rows=col_rows, virtual=True + ) ) yield keyspace_meta @@ -3016,15 +3863,17 @@ def _build_keyspace_metadata_internal(row): row["durable_writes"] = row.get("durable_writes", None) row["replication"] = row.get("replication", {}) row["replication"]["class"] = row["replication"].get("class", None) - return super(SchemaParserV4, SchemaParserV4)._build_keyspace_metadata_internal(row) + return super(SchemaParserV4, SchemaParserV4)._build_keyspace_metadata_internal( + row + ) class SchemaParserDSE67(SchemaParserV4): """ For DSE 6.7+ """ - recognized_table_options = (SchemaParserV4.recognized_table_options + - ("nodesync",)) + + recognized_table_options = SchemaParserV4.recognized_table_options + ("nodesync",) class SchemaParserDSE68(SchemaParserDSE67): @@ -3038,7 +3887,9 @@ class SchemaParserDSE68(SchemaParserDSE67): _table_metadata_class = TableMetadataDSE68 def __init__(self, connection, timeout, fetch_size, metadata_request_timeout): - super(SchemaParserDSE68, self).__init__(connection, timeout, fetch_size, metadata_request_timeout) + super(SchemaParserDSE68, self).__init__( + connection, timeout, fetch_size, metadata_request_timeout + ) self.keyspace_table_vertex_rows = defaultdict(lambda: defaultdict(list)) self.keyspace_table_edge_rows = defaultdict(lambda: defaultdict(list)) @@ -3048,33 +3899,52 @@ def get_all_keyspaces(self): yield keyspace_meta def get_table(self, keyspaces, keyspace, table): - table_meta = super(SchemaParserDSE68, self).get_table(keyspaces, keyspace, table) + table_meta = super(SchemaParserDSE68, self).get_table( + keyspaces, keyspace, table + ) cl = ConsistencyLevel.ONE - where_clause = bind_params(" WHERE keyspace_name = %%s AND %s = %%s" % (self._table_name_col), (keyspace, table), _encoder) + where_clause = bind_params( + " WHERE keyspace_name = %%s AND %s = %%s" % (self._table_name_col), + (keyspace, table), + _encoder, + ) vertices_query = QueryMessage( - query=maybe_add_timeout_to_query(self._SELECT_VERTICES + where_clause, self.metadata_request_timeout), + query=maybe_add_timeout_to_query( + self._SELECT_VERTICES + where_clause, self.metadata_request_timeout + ), consistency_level=cl, ) edges_query = QueryMessage( - query=maybe_add_timeout_to_query(self._SELECT_EDGES + where_clause, self.metadata_request_timeout), + query=maybe_add_timeout_to_query( + self._SELECT_EDGES + where_clause, self.metadata_request_timeout + ), consistency_level=cl, ) - (vertices_success, vertices_result), (edges_success, edges_result) \ - = self.connection.wait_for_responses(vertices_query, edges_query, timeout=self.timeout, fail_on_error=False) + (vertices_success, vertices_result), (edges_success, edges_result) = ( + self.connection.wait_for_responses( + vertices_query, edges_query, timeout=self.timeout, fail_on_error=False + ) + ) vertices_result = self._handle_results(vertices_success, vertices_result) edges_result = self._handle_results(edges_success, edges_result) try: if vertices_result: - table_meta.vertex = self._build_table_vertex_metadata(vertices_result[0]) + table_meta.vertex = self._build_table_vertex_metadata( + vertices_result[0] + ) elif edges_result: - table_meta.edge = self._build_table_edge_metadata(keyspaces[keyspace], edges_result[0]) + table_meta.edge = self._build_table_edge_metadata( + keyspaces[keyspace], edges_result[0] + ) except Exception: table_meta.vertex = None table_meta.edge = None table_meta._exc_info = sys.exc_info() - log.exception("Error while parsing graph metadata for table %s.%s.", keyspace, table) + log.exception( + "Error while parsing graph metadata for table %s.%s.", keyspace, table + ) return table_meta @@ -3082,41 +3952,56 @@ def get_table(self, keyspaces, keyspace, table): def _build_keyspace_metadata_internal(row): name = row["keyspace_name"] durable_writes = row.get("durable_writes", None) - replication = dict(row.get("replication")) if 'replication' in row else {} - replication_class = replication.pop("class") if 'class' in replication else None + replication = dict(row.get("replication")) if "replication" in row else {} + replication_class = replication.pop("class") if "class" in replication else None graph_engine = row.get("graph_engine", None) - return KeyspaceMetadata(name, durable_writes, replication_class, replication, graph_engine) + return KeyspaceMetadata( + name, durable_writes, replication_class, replication, graph_engine + ) def _build_graph_metadata(self, keyspace_meta): def _build_table_graph_metadata(table_meta): - for row in self.keyspace_table_vertex_rows[keyspace_meta.name][table_meta.name]: + for row in self.keyspace_table_vertex_rows[keyspace_meta.name][ + table_meta.name + ]: table_meta.vertex = self._build_table_vertex_metadata(row) - for row in self.keyspace_table_edge_rows[keyspace_meta.name][table_meta.name]: + for row in self.keyspace_table_edge_rows[keyspace_meta.name][ + table_meta.name + ]: table_meta.edge = self._build_table_edge_metadata(keyspace_meta, row) try: # Make sure we process vertices before edges - for table_meta in [t for t in keyspace_meta.tables.values() - if t.name in self.keyspace_table_vertex_rows[keyspace_meta.name]]: + for table_meta in [ + t + for t in keyspace_meta.tables.values() + if t.name in self.keyspace_table_vertex_rows[keyspace_meta.name] + ]: _build_table_graph_metadata(table_meta) # all other tables... - for table_meta in [t for t in keyspace_meta.tables.values() - if t.name not in self.keyspace_table_vertex_rows[keyspace_meta.name]]: + for table_meta in [ + t + for t in keyspace_meta.tables.values() + if t.name not in self.keyspace_table_vertex_rows[keyspace_meta.name] + ]: _build_table_graph_metadata(table_meta) except Exception: # schema error, remove all graph metadata for this keyspace for t in keyspace_meta.tables.values(): t.edge = t.vertex = None keyspace_meta._exc_info = sys.exc_info() - log.exception("Error while parsing graph metadata for keyspace %s", keyspace_meta.name) + log.exception( + "Error while parsing graph metadata for keyspace %s", keyspace_meta.name + ) @staticmethod def _build_table_vertex_metadata(row): - return VertexMetadata(row.get("keyspace_name"), row.get("table_name"), - row.get("label_name")) + return VertexMetadata( + row.get("keyspace_name"), row.get("table_name"), row.get("label_name") + ) @staticmethod def _build_table_edge_metadata(keyspace_meta, row): @@ -3128,37 +4013,113 @@ def _build_table_edge_metadata(keyspace_meta, row): to_label = to_table_meta.vertex.label_name return EdgeMetadata( - row.get("keyspace_name"), row.get("table_name"), - row.get("label_name"), from_table, from_label, + row.get("keyspace_name"), + row.get("table_name"), + row.get("label_name"), + from_table, + from_label, row.get("from_partition_key_columns"), - row.get("from_clustering_columns"), to_table, to_label, + row.get("from_clustering_columns"), + to_table, + to_label, row.get("to_partition_key_columns"), - row.get("to_clustering_columns")) + row.get("to_clustering_columns"), + ) def _query_all(self): cl = ConsistencyLevel.ONE queries = [ # copied from v4 - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_KEYSPACES, self.metadata_request_timeout), - consistency_level=cl), - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_TABLES, self.metadata_request_timeout), consistency_level=cl), - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_COLUMNS, self.metadata_request_timeout), consistency_level=cl), - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_TYPES, self.metadata_request_timeout), consistency_level=cl), - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_FUNCTIONS, self.metadata_request_timeout), consistency_level=cl), - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_AGGREGATES, self.metadata_request_timeout), consistency_level=cl), - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_TRIGGERS, self.metadata_request_timeout), consistency_level=cl), - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_INDEXES, self.metadata_request_timeout), consistency_level=cl), - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_VIEWS, self.metadata_request_timeout), consistency_level=cl), - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_VIRTUAL_KEYSPACES, self.metadata_request_timeout), consistency_level=cl), - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_VIRTUAL_TABLES, self.metadata_request_timeout), consistency_level=cl), - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_VIRTUAL_COLUMNS, self.metadata_request_timeout), consistency_level=cl), + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_KEYSPACES, self.metadata_request_timeout + ), + consistency_level=cl, + ), + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_TABLES, self.metadata_request_timeout + ), + consistency_level=cl, + ), + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_COLUMNS, self.metadata_request_timeout + ), + consistency_level=cl, + ), + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_TYPES, self.metadata_request_timeout + ), + consistency_level=cl, + ), + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_FUNCTIONS, self.metadata_request_timeout + ), + consistency_level=cl, + ), + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_AGGREGATES, self.metadata_request_timeout + ), + consistency_level=cl, + ), + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_TRIGGERS, self.metadata_request_timeout + ), + consistency_level=cl, + ), + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_INDEXES, self.metadata_request_timeout + ), + consistency_level=cl, + ), + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_VIEWS, self.metadata_request_timeout + ), + consistency_level=cl, + ), + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_VIRTUAL_KEYSPACES, self.metadata_request_timeout + ), + consistency_level=cl, + ), + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_VIRTUAL_TABLES, self.metadata_request_timeout + ), + consistency_level=cl, + ), + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_VIRTUAL_COLUMNS, self.metadata_request_timeout + ), + consistency_level=cl, + ), # dse6.8 only - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_VERTICES, self.metadata_request_timeout), consistency_level=cl), - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_EDGES, self.metadata_request_timeout), consistency_level=cl) + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_VERTICES, self.metadata_request_timeout + ), + consistency_level=cl, + ), + QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_EDGES, self.metadata_request_timeout + ), + consistency_level=cl, + ), ] responses = self.connection.wait_for_responses( - *queries, timeout=self.timeout, fail_on_error=False) + *queries, timeout=self.timeout, fail_on_error=False + ) ( # copied from V4 (ks_success, ks_result), @@ -3175,7 +4136,7 @@ def _query_all(self): (virtual_column_success, virtual_column_result), # dse6.8 responses (vertices_success, vertices_result), - (edges_success, edges_result) + (edges_success, edges_result), ) = responses # copied from V4 @@ -3184,24 +4145,29 @@ def _query_all(self): self.columns_result = self._handle_results(col_success, col_result) self.triggers_result = self._handle_results(triggers_success, triggers_result) self.types_result = self._handle_results(types_success, types_result) - self.functions_result = self._handle_results(functions_success, functions_result) - self.aggregates_result = self._handle_results(aggregates_success, aggregates_result) + self.functions_result = self._handle_results( + functions_success, functions_result + ) + self.aggregates_result = self._handle_results( + aggregates_success, aggregates_result + ) self.indexes_result = self._handle_results(indexes_success, indexes_result) self.views_result = self._handle_results(views_success, views_result) # These tables don't exist in some DSE versions reporting 4.X so we can # ignore them if we got an error self.virtual_keyspaces_result = self._handle_results( - virtual_ks_success, virtual_ks_result, - expected_failures=(InvalidRequest,) + virtual_ks_success, virtual_ks_result, expected_failures=(InvalidRequest,) ) self.virtual_tables_result = self._handle_results( - virtual_table_success, virtual_table_result, - expected_failures=(InvalidRequest,) + virtual_table_success, + virtual_table_result, + expected_failures=(InvalidRequest,), ) self.virtual_columns_result = self._handle_results( - virtual_column_success, virtual_column_result, - expected_failures=(InvalidRequest,) + virtual_column_success, + virtual_column_result, + expected_failures=(InvalidRequest,), ) # dse6.8-only results @@ -3216,13 +4182,13 @@ def _aggregate_results(self): m = self.keyspace_table_vertex_rows for row in self.vertices_result: ksname = row["keyspace_name"] - cfname = row['table_name'] + cfname = row["table_name"] m[ksname][cfname].append(row) m = self.keyspace_table_edge_rows for row in self.edges_result: ksname = row["keyspace_name"] - cfname = row['table_name'] + cfname = row["table_name"] m[ksname][cfname].append(row) @@ -3278,7 +4244,15 @@ class MaterializedViewMetadata(object): Metadata describing configuration for table extensions """ - def __init__(self, keyspace_name, view_name, base_table_name, include_all_columns, where_clause, options): + def __init__( + self, + keyspace_name, + view_name, + base_table_name, + include_all_columns, + where_clause, + options, + ): self.keyspace_name = keyspace_name self.name = view_name self.base_table_name = base_table_name @@ -3295,35 +4269,47 @@ def as_cql_query(self, formatted=False): If `formatted` is set to :const:`True`, extra whitespace will be added to make the query more readable. """ - sep = '\n ' if formatted else ' ' + sep = "\n " if formatted else " " keyspace = protect_name(self.keyspace_name) name = protect_name(self.name) - selected_cols = '*' if self.include_all_columns else ', '.join(protect_name(col.name) for col in self.columns.values()) + selected_cols = ( + "*" + if self.include_all_columns + else ", ".join(protect_name(col.name) for col in self.columns.values()) + ) base_table = protect_name(self.base_table_name) where_clause = self.where_clause - part_key = ', '.join(protect_name(col.name) for col in self.partition_key) + part_key = ", ".join(protect_name(col.name) for col in self.partition_key) if len(self.partition_key) > 1: pk = "((%s)" % part_key else: pk = "(%s" % part_key if self.clustering_key: - pk += ", %s" % ', '.join(protect_name(col.name) for col in self.clustering_key) + pk += ", %s" % ", ".join( + protect_name(col.name) for col in self.clustering_key + ) pk += ")" - properties = TableMetadataV3._property_string(formatted, self.clustering_key, self.options) + properties = TableMetadataV3._property_string( + formatted, self.clustering_key, self.options + ) - ret = ("CREATE MATERIALIZED VIEW %(keyspace)s.%(name)s AS%(sep)s" - "SELECT %(selected_cols)s%(sep)s" - "FROM %(keyspace)s.%(base_table)s%(sep)s" - "WHERE %(where_clause)s%(sep)s" - "PRIMARY KEY %(pk)s%(sep)s" - "WITH %(properties)s") % locals() + ret = ( + "CREATE MATERIALIZED VIEW %(keyspace)s.%(name)s AS%(sep)s" + "SELECT %(selected_cols)s%(sep)s" + "FROM %(keyspace)s.%(base_table)s%(sep)s" + "WHERE %(where_clause)s%(sep)s" + "PRIMARY KEY %(pk)s%(sep)s" + "WITH %(properties)s" + ) % locals() if self.extensions: registry = _RegisteredExtensionType._extension_registry - for k in registry.keys() & self.extensions: # no viewkeys on OrderedMapSerializeKey + for k in ( + registry.keys() & self.extensions + ): # no viewkeys on OrderedMapSerializeKey ext = registry[k] cql = ext.after_table_cql(self, k, self.extensions[k]) if cql: @@ -3393,10 +4379,19 @@ class EdgeMetadata(object): """The columns that match the clustering columns of the outgoing vertex table.""" def __init__( - self, keyspace_name, table_name, label_name, from_table, - from_label, from_partition_key_columns, from_clustering_columns, - to_table, to_label, to_partition_key_columns, - to_clustering_columns): + self, + keyspace_name, + table_name, + label_name, + from_table, + from_label, + from_partition_key_columns, + from_clustering_columns, + to_table, + to_label, + to_partition_key_columns, + to_clustering_columns, + ): self.keyspace_name = keyspace_name self.table_name = table_name self.label_name = label_name @@ -3410,14 +4405,20 @@ def __init__( self.to_clustering_columns = to_clustering_columns -def get_column_from_system_local(connection, column_name: str, timeout, metadata_request_timeout) -> str: +def get_column_from_system_local( + connection, column_name: str, timeout, metadata_request_timeout +) -> str: success, local_result = connection.wait_for_response( QueryMessage( query=maybe_add_timeout_to_query( "SELECT " + column_name + " FROM system.local WHERE key='local'", - metadata_request_timeout), - consistency_level=ConsistencyLevel.ONE) - , timeout=timeout, fail_on_error=False) + metadata_request_timeout, + ), + consistency_level=ConsistencyLevel.ONE, + ), + timeout=timeout, + fail_on_error=False, + ) if not success or not local_result.parsed_rows: return "" local_rows = dict_factory(local_result.column_names, local_result.parsed_rows) @@ -3425,29 +4426,48 @@ def get_column_from_system_local(connection, column_name: str, timeout, metadata return local_row.get(column_name) -def get_schema_parser(connection, server_version, dse_version, timeout, metadata_request_timeout, fetch_size=None): +def get_schema_parser( + connection, + server_version, + dse_version, + timeout, + metadata_request_timeout, + fetch_size=None, +): if server_version is None and dse_version is None: - server_version = get_column_from_system_local(connection, "release_version", timeout, metadata_request_timeout) - dse_version = get_column_from_system_local(connection, "dse_version", timeout, metadata_request_timeout) + server_version = get_column_from_system_local( + connection, "release_version", timeout, metadata_request_timeout + ) + dse_version = get_column_from_system_local( + connection, "dse_version", timeout, metadata_request_timeout + ) version = Version(server_version or "0") if dse_version: v = Version(dse_version) - if v >= Version('6.8.0'): - return SchemaParserDSE68(connection, timeout, fetch_size, metadata_request_timeout) - elif v >= Version('6.7.0'): - return SchemaParserDSE67(connection, timeout, fetch_size, metadata_request_timeout) - elif v >= Version('6.0.0'): - return SchemaParserDSE60(connection, timeout, fetch_size, metadata_request_timeout) - - if version >= Version('4-a'): + if v >= Version("6.8.0"): + return SchemaParserDSE68( + connection, timeout, fetch_size, metadata_request_timeout + ) + elif v >= Version("6.7.0"): + return SchemaParserDSE67( + connection, timeout, fetch_size, metadata_request_timeout + ) + elif v >= Version("6.0.0"): + return SchemaParserDSE60( + connection, timeout, fetch_size, metadata_request_timeout + ) + + if version >= Version("4-a"): return SchemaParserV4(connection, timeout, fetch_size, metadata_request_timeout) - elif version >= Version('3.0.0'): + elif version >= Version("3.0.0"): return SchemaParserV3(connection, timeout, fetch_size, metadata_request_timeout) else: # we could further specialize by version. Right now just refactoring the # multi-version parser we have as of C* 2.2.0rc1. - return SchemaParserV22(connection, timeout, fetch_size, metadata_request_timeout) + return SchemaParserV22( + connection, timeout, fetch_size, metadata_request_timeout + ) def _cql_from_cass_type(cass_type): @@ -3466,9 +4486,13 @@ class RLACTableExtension(RegisteredTableExtension): @classmethod def after_table_cql(cls, table_meta, ext_key, ext_blob): - return "RESTRICT ROWS ON %s.%s USING %s;" % (protect_name(table_meta.keyspace_name), - protect_name(table_meta.name), - protect_name(ext_blob.decode('utf-8'))) + return "RESTRICT ROWS ON %s.%s USING %s;" % ( + protect_name(table_meta.keyspace_name), + protect_name(table_meta.name), + protect_name(ext_blob.decode("utf-8")), + ) + + NO_VALID_REPLICA = object() @@ -3490,21 +4514,31 @@ def group_keys_by_replica(session, keyspace, table, keys): partition_keys = cluster.metadata.keyspaces[keyspace].tables[table].partition_key - serializers = list(types._cqltypes[partition_key.cql_type] for partition_key in partition_keys) + serializers = list( + types._cqltypes[partition_key.cql_type] for partition_key in partition_keys + ) keys_per_host = defaultdict(list) distance = cluster._default_load_balancing_policy.distance for key in keys: - serialized_key = [serializer.serialize(pk, cluster.protocol_version) - for serializer, pk in zip(serializers, key)] + serialized_key = [ + serializer.serialize(pk, cluster.protocol_version) + for serializer, pk in zip(serializers, key) + ] if len(serialized_key) == 1: routing_key = serialized_key[0] else: - routing_key = b"".join(struct.pack(">H%dsB" % len(p), len(p), p, 0) for p in serialized_key) + routing_key = b"".join( + struct.pack(">H%dsB" % len(p), len(p), p, 0) for p in serialized_key + ) all_replicas = cluster.metadata.get_replicas(keyspace, routing_key) # First check if there are local replicas - valid_replicas = [host for host in all_replicas if - host.is_up and distance(host) in [HostDistance.LOCAL, HostDistance.LOCAL_RACK]] + valid_replicas = [ + host + for host in all_replicas + if host.is_up + and distance(host) in [HostDistance.LOCAL, HostDistance.LOCAL_RACK] + ] if not valid_replicas: valid_replicas = [host for host in all_replicas if host.is_up] From 4adfaf0f1e535d35a2f546d4c0c892e9de52b02a Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Wed, 4 Mar 2026 19:00:59 +0200 Subject: [PATCH 04/18] remove: dead workaround for CPython bug #10923 (fixed in 3.3) The module-level "".encode("utf8") call was a workaround for CPython bug #10923, where importing the utf8 codec for the first time in a background thread could cause a deadlock due to the import lock. This bug was fixed in CPython 3.3 (2012), and the driver now requires Python 3.9+. The workaround is dead code that serves no purpose and confuses readers. --- cassandra/cluster.py | 2540 ++++++++++++++++++++++++++++++------------ 1 file changed, 1808 insertions(+), 732 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 622b706330..569bb578f1 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -16,6 +16,7 @@ This module houses the main classes you will interact with, :class:`.Cluster` and :class:`.Session`. """ + from __future__ import absolute_import import atexit @@ -43,57 +44,120 @@ import weakref from weakref import WeakValueDictionary -from cassandra import (ConsistencyLevel, AuthenticationFailed, InvalidRequest, - OperationTimedOut, UnsupportedOperation, - SchemaTargetType, DriverException, ProtocolVersion, - UnresolvableContactPoints, DependencyException) +from cassandra import ( + ConsistencyLevel, + AuthenticationFailed, + InvalidRequest, + OperationTimedOut, + UnsupportedOperation, + SchemaTargetType, + DriverException, + ProtocolVersion, + UnresolvableContactPoints, + DependencyException, +) from cassandra.auth import _proxy_execute_key, PlainTextAuthProvider -from cassandra.connection import (ConnectionException, ConnectionShutdown, - ConnectionHeartbeat, ProtocolVersionUnsupported, - EndPoint, DefaultEndPoint, DefaultEndPointFactory, - SniEndPointFactory, ConnectionBusy, locally_supported_compressions) +from cassandra.connection import ( + ConnectionException, + ConnectionShutdown, + ConnectionHeartbeat, + ProtocolVersionUnsupported, + EndPoint, + DefaultEndPoint, + DefaultEndPointFactory, + SniEndPointFactory, + ConnectionBusy, + locally_supported_compressions, +) from cassandra.cqltypes import UserType import cassandra.cqltypes as types from cassandra.encoder import Encoder -from cassandra.protocol import (QueryMessage, ResultMessage, - ErrorMessage, ReadTimeoutErrorMessage, - WriteTimeoutErrorMessage, - UnavailableErrorMessage, - OverloadedErrorMessage, - PrepareMessage, ExecuteMessage, - PreparedQueryNotFound, - IsBootstrappingErrorMessage, - TruncateError, ServerError, - BatchMessage, RESULT_KIND_PREPARED, - RESULT_KIND_SET_KEYSPACE, RESULT_KIND_ROWS, - RESULT_KIND_SCHEMA_CHANGE, ProtocolHandler, - RESULT_KIND_VOID, ProtocolException) +from cassandra.protocol import ( + QueryMessage, + ResultMessage, + ErrorMessage, + ReadTimeoutErrorMessage, + WriteTimeoutErrorMessage, + UnavailableErrorMessage, + OverloadedErrorMessage, + PrepareMessage, + ExecuteMessage, + PreparedQueryNotFound, + IsBootstrappingErrorMessage, + TruncateError, + ServerError, + BatchMessage, + RESULT_KIND_PREPARED, + RESULT_KIND_SET_KEYSPACE, + RESULT_KIND_ROWS, + RESULT_KIND_SCHEMA_CHANGE, + ProtocolHandler, + RESULT_KIND_VOID, + ProtocolException, +) from cassandra.metadata import Metadata, protect_name, murmur3, _NodeInfo -from cassandra.policies import (TokenAwarePolicy, DCAwareRoundRobinPolicy, SimpleConvictionPolicy, - ExponentialReconnectionPolicy, HostDistance, - RetryPolicy, IdentityTranslator, NoSpeculativeExecutionPlan, - NoSpeculativeExecutionPolicy, DefaultLoadBalancingPolicy, - NeverRetryPolicy) -from cassandra.pool import (Host, _ReconnectionHandler, _HostReconnectionHandler, - HostConnection, - NoConnectionsAvailable) -from cassandra.query import (SimpleStatement, PreparedStatement, BoundStatement, - BatchStatement, bind_params, QueryTrace, TraceUnavailable, - named_tuple_factory, dict_factory, tuple_factory, FETCH_SIZE_UNSET, - HostTargetingStatement) +from cassandra.policies import ( + TokenAwarePolicy, + DCAwareRoundRobinPolicy, + SimpleConvictionPolicy, + ExponentialReconnectionPolicy, + HostDistance, + RetryPolicy, + IdentityTranslator, + NoSpeculativeExecutionPlan, + NoSpeculativeExecutionPolicy, + DefaultLoadBalancingPolicy, + NeverRetryPolicy, +) +from cassandra.pool import ( + Host, + _ReconnectionHandler, + _HostReconnectionHandler, + HostConnection, + NoConnectionsAvailable, +) +from cassandra.query import ( + SimpleStatement, + PreparedStatement, + BoundStatement, + BatchStatement, + bind_params, + QueryTrace, + TraceUnavailable, + named_tuple_factory, + dict_factory, + tuple_factory, + FETCH_SIZE_UNSET, + HostTargetingStatement, +) from cassandra.marshal import int64_pack from cassandra.tablets import Tablet, Tablets from cassandra.timestamps import MonotonicTimestampGenerator -from cassandra.util import _resolve_contact_points_to_string_map, Version, maybe_add_timeout_to_query +from cassandra.util import ( + _resolve_contact_points_to_string_map, + Version, + maybe_add_timeout_to_query, +) from cassandra.datastax.insights.reporter import MonitorReporter from cassandra.datastax.insights.util import version_supports_insights -from cassandra.datastax.graph import (graph_object_row_factory, GraphOptions, GraphSON1Serializer, - GraphProtocol, GraphSON2Serializer, GraphStatement, SimpleGraphStatement, - graph_graphson2_row_factory, graph_graphson3_row_factory, - GraphSON3Serializer) -from cassandra.datastax.graph.query import _request_timeout_key, _GraphSONContextRowFactory +from cassandra.datastax.graph import ( + graph_object_row_factory, + GraphOptions, + GraphSON1Serializer, + GraphProtocol, + GraphSON2Serializer, + GraphStatement, + SimpleGraphStatement, + graph_graphson2_row_factory, + graph_graphson3_row_factory, + GraphSON3Serializer, +) +from cassandra.datastax.graph.query import ( + _request_timeout_key, + _GraphSONContextRowFactory, +) from cassandra.datastax import cloud as dscloud from cassandra.application_info import ApplicationInfoBase @@ -114,59 +178,76 @@ except ImportError: from cassandra.util import WeakSet # NOQA + def _is_gevent_monkey_patched(): - if 'gevent.monkey' not in sys.modules: + if "gevent.monkey" not in sys.modules: return False try: import gevent.socket - return socket.socket is gevent.socket.socket # Another case related to PYTHON-1364 + + return ( + socket.socket is gevent.socket.socket + ) # Another case related to PYTHON-1364 except (AttributeError, ImportError): return False + def _try_gevent_import(): if _is_gevent_monkey_patched(): from cassandra.io.geventreactor import GeventConnection - return (GeventConnection,None) + + return (GeventConnection, None) else: - return (None,None) + return (None, None) + def _is_eventlet_monkey_patched(): - if 'eventlet.patcher' not in sys.modules: + if "eventlet.patcher" not in sys.modules: return False try: import eventlet.patcher - return eventlet.patcher.is_monkey_patched('socket') + + return eventlet.patcher.is_monkey_patched("socket") except (ImportError, AttributeError): # AttributeError was add for handling python 3.12 https://github.com/eventlet/eventlet/issues/812 # TODO: remove it when eventlet issue would be fixed return False + def _try_eventlet_import(): if _is_eventlet_monkey_patched(): from cassandra.io.eventletreactor import EventletConnection - return (EventletConnection,None) + + return (EventletConnection, None) else: - return (None,None) + return (None, None) + def _try_libev_import(): try: from cassandra.io.libevreactor import LibevConnection - return (LibevConnection,None) + + return (LibevConnection, None) except DependencyException as e: return (None, e) + def _try_asyncore_import(): try: from cassandra.io.asyncorereactor import AsyncoreConnection - return (AsyncoreConnection,None) + + return (AsyncoreConnection, None) except DependencyException as e: return (None, e) + def _try_asyncio_import(): from cassandra.io.asyncioreactor import AsyncioConnection + return (AsyncioConnection, None) -def _connection_reduce_fn(val,import_fn): + +def _connection_reduce_fn(val, import_fn): (rv, excs) = val # If we've already found a workable Connection class return immediately if rv: @@ -176,22 +257,23 @@ def _connection_reduce_fn(val,import_fn): excs.append(exc) return (rv or import_result, excs) -conn_fns = (_try_gevent_import, _try_eventlet_import, _try_libev_import, _try_asyncore_import, _try_asyncio_import) -(conn_class, excs) = reduce(_connection_reduce_fn, conn_fns, (None,[])) + +conn_fns = ( + _try_gevent_import, + _try_eventlet_import, + _try_libev_import, + _try_asyncore_import, + _try_asyncio_import, +) +(conn_class, excs) = reduce(_connection_reduce_fn, conn_fns, (None, [])) if not conn_class: raise DependencyException("Exception loading connection class dependencies", excs) DefaultConnection = conn_class -# Forces load of utf8 encoding module to avoid deadlock that occurs -# if code that is being imported tries to import the module in a seperate -# thread. -# See http://bugs.python.org/issue10923 -"".encode('utf8') - log = logging.getLogger(__name__) -_GRAPH_PAGING_MIN_DSE_VERSION = Version('6.8.0') +_GRAPH_PAGING_MIN_DSE_VERSION = Version("6.8.0") _NOT_SET = object() @@ -215,7 +297,7 @@ def __init__(self, message, errors): def _future_completed(future): - """ Helper for run_in_executor() """ + """Helper for run_in_executor()""" exc = future.exception() if exc: log.debug("Failed to run task on executor", exc_info=exc) @@ -252,7 +334,9 @@ def _discard_cluster_shutdown(cluster): def _shutdown_clusters(): - clusters = _clusters_for_shutdown.copy() # copy because shutdown modifies the global set "discard" + clusters = ( + _clusters_for_shutdown.copy() + ) # copy because shutdown modifies the global set "discard" for cluster in clusters: cluster.shutdown() @@ -267,7 +351,6 @@ def default_lbp_factory(): class ContinuousPagingOptions(object): - class PagingUnit(object): BYTES = 1 ROWS = 2 @@ -295,12 +378,20 @@ class PagingUnit(object): by default it is 4 and it must be at least 2. """ - def __init__(self, page_unit=PagingUnit.ROWS, max_pages=0, max_pages_per_second=0, max_queue_size=4): + def __init__( + self, + page_unit=PagingUnit.ROWS, + max_pages=0, + max_pages_per_second=0, + max_queue_size=4, + ): self.page_unit = page_unit self.max_pages = max_pages self.max_pages_per_second = max_pages_per_second if max_queue_size < 2: - raise ValueError('ContinuousPagingOptions.max_queue_size must be 2 or greater') + raise ValueError( + "ContinuousPagingOptions.max_queue_size must be 2 or greater" + ) self.max_queue_size = max_queue_size def page_unit_bytes(self): @@ -314,20 +405,22 @@ def _addrinfo_or_none(contact_point, port): PYTHON-895. """ try: - return socket.getaddrinfo(contact_point, port, - socket.AF_UNSPEC, socket.SOCK_STREAM) + return socket.getaddrinfo( + contact_point, port, socket.AF_UNSPEC, socket.SOCK_STREAM + ) except socket.gaierror: - log.debug('Could not resolve hostname "{}" ' - 'with port {}'.format(contact_point, port)) + log.debug( + 'Could not resolve hostname "{}" with port {}'.format(contact_point, port) + ) return None def _execution_profile_to_string(name): default_profiles = { - EXEC_PROFILE_DEFAULT: 'EXEC_PROFILE_DEFAULT', - EXEC_PROFILE_GRAPH_DEFAULT: 'EXEC_PROFILE_GRAPH_DEFAULT', - EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT: 'EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT', - EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT: 'EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT', + EXEC_PROFILE_DEFAULT: "EXEC_PROFILE_DEFAULT", + EXEC_PROFILE_GRAPH_DEFAULT: "EXEC_PROFILE_GRAPH_DEFAULT", + EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT: "EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT", + EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT: "EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT", } if name in default_profiles: @@ -407,10 +500,17 @@ class ExecutionProfile(object): _load_balancing_policy_explicit = False _consistency_level_explicit = False - def __init__(self, load_balancing_policy=_NOT_SET, retry_policy=None, - consistency_level=_NOT_SET, serial_consistency_level=None, - request_timeout=10.0, row_factory=named_tuple_factory, speculative_execution_policy=None, - continuous_paging_options=None): + def __init__( + self, + load_balancing_policy=_NOT_SET, + retry_policy=None, + consistency_level=_NOT_SET, + serial_consistency_level=None, + request_timeout=10.0, + row_factory=named_tuple_factory, + speculative_execution_policy=None, + continuous_paging_options=None, + ): if load_balancing_policy is _NOT_SET: self._load_balancing_policy_explicit = False @@ -428,16 +528,21 @@ def __init__(self, load_balancing_policy=_NOT_SET, retry_policy=None, self.retry_policy = retry_policy or RetryPolicy() - if (serial_consistency_level is not None and - not ConsistencyLevel.is_serial(serial_consistency_level)): - raise ValueError("serial_consistency_level must be either " - "ConsistencyLevel.SERIAL " - "or ConsistencyLevel.LOCAL_SERIAL.") + if serial_consistency_level is not None and not ConsistencyLevel.is_serial( + serial_consistency_level + ): + raise ValueError( + "serial_consistency_level must be either " + "ConsistencyLevel.SERIAL " + "or ConsistencyLevel.LOCAL_SERIAL." + ) self.serial_consistency_level = serial_consistency_level self.request_timeout = request_timeout self.row_factory = row_factory - self.speculative_execution_policy = speculative_execution_policy or NoSpeculativeExecutionPolicy() + self.speculative_execution_policy = ( + speculative_execution_policy or NoSpeculativeExecutionPolicy() + ) self.continuous_paging_options = continuous_paging_options @@ -453,10 +558,17 @@ class GraphExecutionProfile(ExecutionProfile): See cassandra.graph.GraphOptions """ - def __init__(self, load_balancing_policy=_NOT_SET, retry_policy=None, - consistency_level=_NOT_SET, serial_consistency_level=None, - request_timeout=30.0, row_factory=None, - graph_options=None, continuous_paging_options=_NOT_SET): + def __init__( + self, + load_balancing_policy=_NOT_SET, + retry_policy=None, + consistency_level=_NOT_SET, + serial_consistency_level=None, + request_timeout=30.0, + row_factory=None, + graph_options=None, + continuous_paging_options=_NOT_SET, + ): """ Default execution profile for graph execution. @@ -469,19 +581,31 @@ def __init__(self, load_balancing_policy=_NOT_SET, retry_policy=None, :class:`cassandra.policies.NeverRetryPolicy`. """ retry_policy = retry_policy or NeverRetryPolicy() - super(GraphExecutionProfile, self).__init__(load_balancing_policy, retry_policy, consistency_level, - serial_consistency_level, request_timeout, row_factory, - continuous_paging_options=continuous_paging_options) - self.graph_options = graph_options or GraphOptions(graph_source=b'g', - graph_language=b'gremlin-groovy') + super(GraphExecutionProfile, self).__init__( + load_balancing_policy, + retry_policy, + consistency_level, + serial_consistency_level, + request_timeout, + row_factory, + continuous_paging_options=continuous_paging_options, + ) + self.graph_options = graph_options or GraphOptions( + graph_source=b"g", graph_language=b"gremlin-groovy" + ) class GraphAnalyticsExecutionProfile(GraphExecutionProfile): - - def __init__(self, load_balancing_policy=None, retry_policy=None, - consistency_level=_NOT_SET, serial_consistency_level=None, - request_timeout=3600. * 24. * 7., row_factory=None, - graph_options=None): + def __init__( + self, + load_balancing_policy=None, + retry_policy=None, + consistency_level=_NOT_SET, + serial_consistency_level=None, + request_timeout=3600.0 * 24.0 * 7.0, + row_factory=None, + graph_options=None, + ): """ Execution profile with timeout and load balancing appropriate for graph analytics queries. @@ -494,35 +618,50 @@ def __init__(self, load_balancing_policy=None, retry_policy=None, Note: The graph_options.graph_source is set automatically to b'a' (analytics) when using GraphAnalyticsExecutionProfile. This is mandatory to target analytics nodes. """ - load_balancing_policy = load_balancing_policy or DefaultLoadBalancingPolicy(default_lbp_factory()) - graph_options = graph_options or GraphOptions(graph_language=b'gremlin-groovy') - super(GraphAnalyticsExecutionProfile, self).__init__(load_balancing_policy, retry_policy, consistency_level, - serial_consistency_level, request_timeout, row_factory, - graph_options) + load_balancing_policy = load_balancing_policy or DefaultLoadBalancingPolicy( + default_lbp_factory() + ) + graph_options = graph_options or GraphOptions(graph_language=b"gremlin-groovy") + super(GraphAnalyticsExecutionProfile, self).__init__( + load_balancing_policy, + retry_policy, + consistency_level, + serial_consistency_level, + request_timeout, + row_factory, + graph_options, + ) # ensure the graph_source is analytics, since this is the purpose of the GraphAnalyticsExecutionProfile self.graph_options.set_source_analytics() class ProfileManager(object): - def __init__(self): self.profiles = dict() def _profiles_without_explicit_lbps(self): - names = (profile_name for - profile_name, profile in self.profiles.items() - if not profile._load_balancing_policy_explicit) + names = ( + profile_name + for profile_name, profile in self.profiles.items() + if not profile._load_balancing_policy_explicit + ) return tuple( - 'EXEC_PROFILE_DEFAULT' if n is EXEC_PROFILE_DEFAULT else n - for n in names + "EXEC_PROFILE_DEFAULT" if n is EXEC_PROFILE_DEFAULT else n for n in names ) def distance(self, host): - distances = set(p.load_balancing_policy.distance(host) for p in self.profiles.values()) - return HostDistance.LOCAL_RACK if HostDistance.LOCAL_RACK in distances else \ - HostDistance.LOCAL if HostDistance.LOCAL in distances else \ - HostDistance.REMOTE if HostDistance.REMOTE in distances else \ - HostDistance.IGNORED + distances = set( + p.load_balancing_policy.distance(host) for p in self.profiles.values() + ) + return ( + HostDistance.LOCAL_RACK + if HostDistance.LOCAL_RACK in distances + else HostDistance.LOCAL + if HostDistance.LOCAL in distances + else HostDistance.REMOTE + if HostDistance.REMOTE in distances + else HostDistance.IGNORED + ) def populate(self, cluster, hosts): for p in self.profiles.values(): @@ -629,7 +768,7 @@ class Cluster(object): which implicitly handle shutdown when leaving scope. """ - contact_points = ['127.0.0.1'] + contact_points = ["127.0.0.1"] """ The list of contact points to try connecting for cluster discovery. A contact point can be a string (ip or hostname), a tuple (ip/hostname, port) or a @@ -736,15 +875,20 @@ def auth_provider(self, value): self._auth_provider_callable = value.new_authenticator except AttributeError: if self.protocol_version > 1: - raise TypeError("auth_provider must implement the cassandra.auth.AuthProvider " - "interface when protocol_version >= 2") + raise TypeError( + "auth_provider must implement the cassandra.auth.AuthProvider " + "interface when protocol_version >= 2" + ) elif not callable(value): - raise TypeError("auth_provider must be callable when protocol_version == 1") + raise TypeError( + "auth_provider must be callable when protocol_version == 1" + ) self._auth_provider_callable = value self._auth_provider = value _load_balancing_policy = None + @property def load_balancing_policy(self): """ @@ -765,7 +909,9 @@ def load_balancing_policy(self): @load_balancing_policy.setter def load_balancing_policy(self, lbp): if self._config_mode == _ConfigMode.PROFILES: - raise ValueError("Cannot set Cluster.load_balancing_policy while using Configuration Profiles. Set this in a profile instead.") + raise ValueError( + "Cannot set Cluster.load_balancing_policy while using Configuration Profiles. Set this in a profile instead." + ) self._load_balancing_policy = lbp self._config_mode = _ConfigMode.LEGACY @@ -781,6 +927,7 @@ def _default_load_balancing_policy(self): """ _default_retry_policy = RetryPolicy() + @property def default_retry_policy(self): """ @@ -793,7 +940,9 @@ def default_retry_policy(self): @default_retry_policy.setter def default_retry_policy(self, policy): if self._config_mode == _ConfigMode.PROFILES: - raise ValueError("Cannot set Cluster.default_retry_policy while using Configuration Profiles. Set this in a profile instead.") + raise ValueError( + "Cannot set Cluster.default_retry_policy while using Configuration Profiles. Set this in a profile instead." + ) self._default_retry_policy = policy self._config_mode = _ConfigMode.LEGACY @@ -1044,12 +1193,12 @@ def default_retry_policy(self, policy): be generated automatically unless the user provides one. """ - application_name = '' + application_name = "" """ A string identifying this application to Insights. """ - application_version = '' + application_version = "" """ A string identifiying this application's version to Insights """ @@ -1169,54 +1318,55 @@ def token_metadata_enabled(self, enabled): _listeners = None _listener_lock = None - def __init__(self, - contact_points=_NOT_SET, - port=9042, - compression: Union[bool, str, None] = True, - auth_provider=None, - load_balancing_policy=None, - reconnection_policy=None, - default_retry_policy=None, - conviction_policy_factory=None, - metrics_enabled=False, - connection_class=None, - ssl_options=None, - sockopts=None, - cql_version=None, - protocol_version=_NOT_SET, - executor_threads=2, - max_schema_agreement_wait=10, - control_connection_timeout=2.0, - idle_heartbeat_interval=30, - schema_event_refresh_window=2, - topology_event_refresh_window=10, - connect_timeout=5, - schema_metadata_enabled=True, - token_metadata_enabled=True, - schema_metadata_page_size=1000, - address_translator=None, - status_event_refresh_window=2, - prepare_on_all_hosts=True, - reprepare_on_up=True, - execution_profiles=None, - allow_beta_protocol_version=False, - timestamp_generator=None, - idle_heartbeat_timeout=30, - no_compact=False, - ssl_context=None, - endpoint_factory=None, - application_name=None, - application_version=None, - monitor_reporting_enabled=True, - monitor_reporting_interval=30, - client_id=None, - cloud=None, - scylla_cloud=None, - shard_aware_options=None, - metadata_request_timeout: Optional[float] = None, - column_encryption_policy=None, - application_info:Optional[ApplicationInfoBase]=None - ): + def __init__( + self, + contact_points=_NOT_SET, + port=9042, + compression: Union[bool, str, None] = True, + auth_provider=None, + load_balancing_policy=None, + reconnection_policy=None, + default_retry_policy=None, + conviction_policy_factory=None, + metrics_enabled=False, + connection_class=None, + ssl_options=None, + sockopts=None, + cql_version=None, + protocol_version=_NOT_SET, + executor_threads=2, + max_schema_agreement_wait=10, + control_connection_timeout=2.0, + idle_heartbeat_interval=30, + schema_event_refresh_window=2, + topology_event_refresh_window=10, + connect_timeout=5, + schema_metadata_enabled=True, + token_metadata_enabled=True, + schema_metadata_page_size=1000, + address_translator=None, + status_event_refresh_window=2, + prepare_on_all_hosts=True, + reprepare_on_up=True, + execution_profiles=None, + allow_beta_protocol_version=False, + timestamp_generator=None, + idle_heartbeat_timeout=30, + no_compact=False, + ssl_context=None, + endpoint_factory=None, + application_name=None, + application_version=None, + monitor_reporting_enabled=True, + monitor_reporting_interval=30, + client_id=None, + cloud=None, + scylla_cloud=None, + shard_aware_options=None, + metadata_request_timeout: Optional[float] = None, + column_encryption_policy=None, + application_info: Optional[ApplicationInfoBase] = None, + ): """ ``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as extablishing connection pools or refreshing metadata. @@ -1227,7 +1377,9 @@ def __init__(self, # Handle port passed as string if isinstance(port, str): if not port.isdigit(): - raise ValueError("Only numeric values are supported for port (%s)" % port) + raise ValueError( + "Only numeric values are supported for port (%s)" % port + ) port = int(port) if port < 1 or port > 65535: @@ -1237,25 +1389,47 @@ def __init__(self, self.connection_class = connection_class if scylla_cloud is not None: - raise NotImplementedError("scylla_cloud was removed and not supported anymore") + raise NotImplementedError( + "scylla_cloud was removed and not supported anymore" + ) if cloud is not None: self.cloud = cloud - if contact_points is not _NOT_SET or endpoint_factory or ssl_context or ssl_options: - raise ValueError("contact_points, endpoint_factory, ssl_context, and ssl_options " - "cannot be specified with a cloud configuration") + if ( + contact_points is not _NOT_SET + or endpoint_factory + or ssl_context + or ssl_options + ): + raise ValueError( + "contact_points, endpoint_factory, ssl_context, and ssl_options " + "cannot be specified with a cloud configuration" + ) - uses_twisted = TwistedConnection and issubclass(self.connection_class, TwistedConnection) - uses_eventlet = EventletConnection and issubclass(self.connection_class, EventletConnection) - cloud_config = dscloud.get_cloud_config(cloud, create_pyopenssl_context=uses_twisted or uses_eventlet) + uses_twisted = TwistedConnection and issubclass( + self.connection_class, TwistedConnection + ) + uses_eventlet = EventletConnection and issubclass( + self.connection_class, EventletConnection + ) + cloud_config = dscloud.get_cloud_config( + cloud, create_pyopenssl_context=uses_twisted or uses_eventlet + ) ssl_context = cloud_config.ssl_context - ssl_options = {'check_hostname': True} - if (auth_provider is None and cloud_config.username - and cloud_config.password): - auth_provider = PlainTextAuthProvider(cloud_config.username, cloud_config.password) + ssl_options = {"check_hostname": True} + if ( + auth_provider is None + and cloud_config.username + and cloud_config.password + ): + auth_provider = PlainTextAuthProvider( + cloud_config.username, cloud_config.password + ) - endpoint_factory = SniEndPointFactory(cloud_config.sni_host, cloud_config.sni_port) + endpoint_factory = SniEndPointFactory( + cloud_config.sni_host, cloud_config.sni_port + ) contact_points = [ endpoint_factory.create_from_sni(host_id) for host_id in cloud_config.host_ids @@ -1264,15 +1438,19 @@ def __init__(self, if contact_points is not None: if contact_points is _NOT_SET: self._contact_points_explicit = False - contact_points = ['127.0.0.1'] + contact_points = ["127.0.0.1"] else: self._contact_points_explicit = True if isinstance(contact_points, str): - raise TypeError("contact_points should not be a string, it should be a sequence (e.g. list) of strings") + raise TypeError( + "contact_points should not be a string, it should be a sequence (e.g. list) of strings" + ) if None in contact_points: - raise ValueError("contact_points should not contain None (it can resolve to localhost)") + raise ValueError( + "contact_points should not contain None (it can resolve to localhost)" + ) self.contact_points = contact_points self.port = port @@ -1280,7 +1458,9 @@ def __init__(self, if column_encryption_policy is not None: self.column_encryption_policy = column_encryption_policy - self.endpoint_factory = endpoint_factory or DefaultEndPointFactory(port=self.port) + self.endpoint_factory = endpoint_factory or DefaultEndPointFactory( + port=self.port + ) self.endpoint_factory.configure(self) self._resolve_hostnames() @@ -1297,7 +1477,8 @@ def __init__(self, if not locally_supported_compressions.get(compression): raise ValueError( "Compression '%s' was requested, but it is not available. " - "Consider installing the corresponding Python package." % compression + "Consider installing the corresponding Python package." + % compression ) else: raise TypeError( @@ -1317,19 +1498,27 @@ def __init__(self, if load_balancing_policy is not None: if isinstance(load_balancing_policy, type): - raise TypeError("load_balancing_policy should not be a class, it should be an instance of that class") + raise TypeError( + "load_balancing_policy should not be a class, it should be an instance of that class" + ) self.load_balancing_policy = load_balancing_policy else: - self._load_balancing_policy = default_lbp_factory() # set internal attribute to avoid committing to legacy config mode + self._load_balancing_policy = ( + default_lbp_factory() + ) # set internal attribute to avoid committing to legacy config mode if reconnection_policy is not None: if isinstance(reconnection_policy, type): - raise TypeError("reconnection_policy should not be a class, it should be an instance of that class") + raise TypeError( + "reconnection_policy should not be a class, it should be an instance of that class" + ) self.reconnection_policy = reconnection_policy if default_retry_policy is not None: if isinstance(default_retry_policy, type): - raise TypeError("default_retry_policy should not be a class, it should be an instance of that class") + raise TypeError( + "default_retry_policy should not be a class, it should be an instance of that class" + ) self.default_retry_policy = default_retry_policy if conviction_policy_factory is not None: @@ -1339,13 +1528,16 @@ def __init__(self, if address_translator is not None: if isinstance(address_translator, type): - raise TypeError("address_translator should not be a class, it should be an instance of that class") + raise TypeError( + "address_translator should not be a class, it should be an instance of that class" + ) self.address_translator = address_translator if application_info is not None: if not isinstance(application_info, ApplicationInfoBase): raise TypeError( - "application_info should be an instance of any ApplicationInfoBase class") + "application_info should be an instance of any ApplicationInfoBase class" + ) self._application_info = application_info if timestamp_generator is not None: @@ -1360,18 +1552,23 @@ def __init__(self, self.load_balancing_policy, self.default_retry_policy, request_timeout=Session._default_timeout, - row_factory=Session._row_factory + row_factory=Session._row_factory, ) # legacy mode if either of these is not default if load_balancing_policy or default_retry_policy: if execution_profiles: - raise ValueError("Clusters constructed with execution_profiles should not specify legacy parameters " - "load_balancing_policy or default_retry_policy. Configure this in a profile instead.") + raise ValueError( + "Clusters constructed with execution_profiles should not specify legacy parameters " + "load_balancing_policy or default_retry_policy. Configure this in a profile instead." + ) self._config_mode = _ConfigMode.LEGACY - warn("Legacy execution parameters will be removed in 4.0. Consider using " - "execution profiles.", DeprecationWarning) + warn( + "Legacy execution parameters will be removed in 4.0. Consider using " + "execution profiles.", + DeprecationWarning, + ) else: profiles = self.profile_manager.profiles @@ -1379,44 +1576,63 @@ def __init__(self, profiles.update(execution_profiles) self._config_mode = _ConfigMode.PROFILES - lbp = DefaultLoadBalancingPolicy(self.profile_manager.default.load_balancing_policy) - profiles.setdefault(EXEC_PROFILE_GRAPH_DEFAULT, GraphExecutionProfile(load_balancing_policy=lbp)) - profiles.setdefault(EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT, - GraphExecutionProfile(load_balancing_policy=lbp, request_timeout=60. * 3.)) - profiles.setdefault(EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT, - GraphAnalyticsExecutionProfile(load_balancing_policy=lbp)) + lbp = DefaultLoadBalancingPolicy( + self.profile_manager.default.load_balancing_policy + ) + profiles.setdefault( + EXEC_PROFILE_GRAPH_DEFAULT, + GraphExecutionProfile(load_balancing_policy=lbp), + ) + profiles.setdefault( + EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT, + GraphExecutionProfile( + load_balancing_policy=lbp, request_timeout=60.0 * 3.0 + ), + ) + profiles.setdefault( + EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT, + GraphAnalyticsExecutionProfile(load_balancing_policy=lbp), + ) - if self._contact_points_explicit and not self.cloud: # avoid this warning for cloud users. + if ( + self._contact_points_explicit and not self.cloud + ): # avoid this warning for cloud users. if self._config_mode is _ConfigMode.PROFILES: - default_lbp_profiles = self.profile_manager._profiles_without_explicit_lbps() + default_lbp_profiles = ( + self.profile_manager._profiles_without_explicit_lbps() + ) if default_lbp_profiles: log.warning( - 'Cluster.__init__ called with contact_points ' - 'specified, but load-balancing policies are not ' - 'specified in some ExecutionProfiles. In the next ' - 'major version, this will raise an error; please ' - 'specify a load-balancing policy. ' - '(contact_points = {cp}, ' - 'EPs without explicit LBPs = {eps})' - ''.format(cp=contact_points, eps=default_lbp_profiles)) + "Cluster.__init__ called with contact_points " + "specified, but load-balancing policies are not " + "specified in some ExecutionProfiles. In the next " + "major version, this will raise an error; please " + "specify a load-balancing policy. " + "(contact_points = {cp}, " + "EPs without explicit LBPs = {eps})" + "".format(cp=contact_points, eps=default_lbp_profiles) + ) else: if load_balancing_policy is None: log.warning( - 'Cluster.__init__ called with contact_points ' - 'specified, but no load_balancing_policy. In the next ' - 'major version, this will raise an error; please ' - 'specify a load-balancing policy. ' - '(contact_points = {cp}, lbp = {lbp})' - ''.format(cp=contact_points, lbp=load_balancing_policy)) + "Cluster.__init__ called with contact_points " + "specified, but no load_balancing_policy. In the next " + "major version, this will raise an error; please " + "specify a load-balancing policy. " + "(contact_points = {cp}, lbp = {lbp})" + "".format(cp=contact_points, lbp=load_balancing_policy) + ) self.metrics_enabled = metrics_enabled if ssl_options and not ssl_context: - warn('Using ssl_options without ssl_context is ' - 'deprecated and will result in an error in ' - 'the next major release. Please use ssl_context ' - 'to prepare for that release.', - DeprecationWarning) + warn( + "Using ssl_options without ssl_context is " + "deprecated and will result in an error in " + "the next major release. Please use ssl_context " + "to prepare for that release.", + DeprecationWarning, + ) self.ssl_options = ssl_options self.ssl_context = ssl_context @@ -1424,7 +1640,11 @@ def __init__(self, self.cql_version = cql_version self.max_schema_agreement_wait = max_schema_agreement_wait self.control_connection_timeout = control_connection_timeout - self.metadata_request_timeout = self.control_connection_timeout if metadata_request_timeout is None else metadata_request_timeout + self.metadata_request_timeout = ( + self.control_connection_timeout + if metadata_request_timeout is None + else metadata_request_timeout + ) self.idle_heartbeat_interval = idle_heartbeat_interval self.idle_heartbeat_timeout = idle_heartbeat_timeout self.schema_event_refresh_window = schema_event_refresh_window @@ -1457,14 +1677,19 @@ def __init__(self, if self.metrics_enabled: from cassandra.metrics import Metrics + self.metrics = Metrics(weakref.proxy(self)) self.control_connection = ControlConnection( - self, self.control_connection_timeout, - self.schema_event_refresh_window, self.topology_event_refresh_window, + self, + self.control_connection_timeout, + self.schema_event_refresh_window, + self.topology_event_refresh_window, self.status_event_refresh_window, - schema_metadata_enabled, token_metadata_enabled, - schema_meta_page_size=schema_metadata_page_size) + schema_metadata_enabled, + token_metadata_enabled, + schema_meta_page_size=schema_metadata_page_size, + ) if client_id is None: self.client_id = uuid.uuid4() @@ -1478,20 +1703,32 @@ def _resolve_hostnames(self): for cp in [cp for cp in self.contact_points if not isinstance(cp, EndPoint)]: raw_contact_points.append(cp if isinstance(cp, tuple) else (cp, self.port)) - self.endpoints_resolved = [cp for cp in self.contact_points if isinstance(cp, EndPoint)] - self._endpoint_map_for_insights = {repr(ep): '{ip}:{port}'.format(ip=ep.address, port=ep.port) - for ep in self.endpoints_resolved} + self.endpoints_resolved = [ + cp for cp in self.contact_points if isinstance(cp, EndPoint) + ] + self._endpoint_map_for_insights = { + repr(ep): "{ip}:{port}".format(ip=ep.address, port=ep.port) + for ep in self.endpoints_resolved + } strs_resolved_map = _resolve_contact_points_to_string_map(raw_contact_points) - self.endpoints_resolved.extend(list(chain( - *[ - [DefaultEndPoint(ip, port) for ip, port in xs if ip is not None] - for xs in strs_resolved_map.values() if xs is not None - ] - ))) + self.endpoints_resolved.extend( + list( + chain( + *[ + [DefaultEndPoint(ip, port) for ip, port in xs if ip is not None] + for xs in strs_resolved_map.values() + if xs is not None + ] + ) + ) + ) self._endpoint_map_for_insights.update( - {key: ['{ip}:{port}'.format(ip=ip, port=port) for ip, port in value] - for key, value in strs_resolved_map.items() if value is not None} + { + key: ["{ip}:{port}".format(ip=ip, port=port) for ip, port in value] + for key, value in strs_resolved_map.items() + if value is not None + } ) if self.contact_points and (not self.endpoints_resolved): @@ -1514,6 +1751,7 @@ def _create_thread_pool_executor(self, **kwargs): if sys.version_info[0] >= 3 and sys.version_info[1] >= 7: try: from cassandra.io.eventletreactor import EventletConnection + is_eventlet = issubclass(self.connection_class, EventletConnection) except: # Eventlet is not available or can't be detected @@ -1522,15 +1760,19 @@ def _create_thread_pool_executor(self, **kwargs): if is_eventlet: try: from futurist import GreenThreadPoolExecutor + tpe_class = GreenThreadPoolExecutor except ImportError: # futurist is not available raise ImportError( - ("Python 3.7+ and Eventlet cause the `concurrent.futures.ThreadPoolExecutor` " - "to hang indefinitely. If you want to use the Eventlet reactor, you " - "need to install the `futurist` package to allow the driver to use " - "the GreenThreadPoolExecutor. See https://github.com/eventlet/eventlet/issues/508 " - "for more details.")) + ( + "Python 3.7+ and Eventlet cause the `concurrent.futures.ThreadPoolExecutor` " + "to hang indefinitely. If you want to use the Eventlet reactor, you " + "need to install the `futurist` package to allow the driver to use " + "the GreenThreadPoolExecutor. See https://github.com/eventlet/eventlet/issues/508 " + "for more details." + ) + ) return tpe_class(**kwargs) @@ -1583,9 +1825,14 @@ def __init__(self, street, zipcode): """ if self.protocol_version < 3: - log.warning("User Type serialization is only supported in native protocol version 3+ (%d in use). " - "CQL encoding for simple statements will still work, but named tuples will " - "be returned when reading type %s.%s.", self.protocol_version, keyspace, user_type) + log.warning( + "User Type serialization is only supported in native protocol version 3+ (%d in use). " + "CQL encoding for simple statements will still work, but named tuples will " + "be returned when reading type %s.%s.", + self.protocol_version, + keyspace, + user_type, + ) self._user_types[keyspace][user_type] = klass for session in tuple(self.sessions): @@ -1609,23 +1856,29 @@ def add_execution_profile(self, name, profile, pool_wait_timeout=5): if not isinstance(profile, ExecutionProfile): raise TypeError("profile must be an instance of ExecutionProfile") if self._config_mode == _ConfigMode.LEGACY: - raise ValueError("Cannot add execution profiles when legacy parameters are set explicitly.") + raise ValueError( + "Cannot add execution profiles when legacy parameters are set explicitly." + ) if name in self.profile_manager.profiles: raise ValueError("Profile {} already exists".format(name)) contact_points_but_no_lbp = ( - self._contact_points_explicit and not - profile._load_balancing_policy_explicit) + self._contact_points_explicit + and not profile._load_balancing_policy_explicit + ) if contact_points_but_no_lbp: log.warning( - 'Tried to add an ExecutionProfile with name {name}. ' - '{self} was explicitly configured with contact_points, but ' - '{ep} was not explicitly configured with a ' - 'load_balancing_policy. In the next major version, trying to ' - 'add an ExecutionProfile without an explicitly configured LBP ' - 'to a cluster with explicitly configured contact_points will ' - 'raise an exception; please specify a load-balancing policy ' - 'in the ExecutionProfile.' - ''.format(name=_execution_profile_to_string(name), self=self, ep=profile)) + "Tried to add an ExecutionProfile with name {name}. " + "{self} was explicitly configured with contact_points, but " + "{ep} was not explicitly configured with a " + "load_balancing_policy. In the next major version, trying to " + "add an ExecutionProfile without an explicitly configured LBP " + "to a cluster with explicitly configured contact_points will " + "raise an exception; please specify a load-balancing policy " + "in the ExecutionProfile." + "".format( + name=_execution_profile_to_string(name), self=self, ep=profile + ) + ) self.profile_manager.profiles[name] = profile profile.load_balancing_policy.populate(self, self.metadata.all_hosts()) @@ -1638,54 +1891,77 @@ def add_execution_profile(self, name, profile, pool_wait_timeout=5): futures.update(session.update_created_pools()) _, not_done = wait_futures(futures, pool_wait_timeout) if not_done: - raise OperationTimedOut("Failed to create all new connection pools in the %ss timeout.") + raise OperationTimedOut( + "Failed to create all new connection pools in the %ss timeout." + ) - def connection_factory(self, endpoint, host_conn = None, *args, **kwargs): + def connection_factory(self, endpoint, host_conn=None, *args, **kwargs): """ Called to create a new connection with proper configuration. Intended for internal use only. """ kwargs = self._make_connection_kwargs(endpoint, kwargs) - return self.connection_class.factory(endpoint, self.connect_timeout, host_conn, *args, **kwargs) + return self.connection_class.factory( + endpoint, self.connect_timeout, host_conn, *args, **kwargs + ) def _make_connection_factory(self, host, *args, **kwargs): kwargs = self._make_connection_kwargs(host.endpoint, kwargs) - return partial(self.connection_class.factory, host.endpoint, self.connect_timeout, *args, **kwargs) + return partial( + self.connection_class.factory, + host.endpoint, + self.connect_timeout, + *args, + **kwargs, + ) def _make_connection_kwargs(self, endpoint, kwargs_dict): if self._auth_provider_callable: - kwargs_dict.setdefault('authenticator', self._auth_provider_callable(endpoint.address)) - - kwargs_dict.setdefault('port', self.port) - kwargs_dict.setdefault('compression', self.compression) - kwargs_dict.setdefault('sockopts', self.sockopts) - kwargs_dict.setdefault('ssl_options', self.ssl_options) - kwargs_dict.setdefault('ssl_context', self.ssl_context) - kwargs_dict.setdefault('cql_version', self.cql_version) - kwargs_dict.setdefault('protocol_version', self.protocol_version) - kwargs_dict.setdefault('user_type_map', self._user_types) - kwargs_dict.setdefault('allow_beta_protocol_version', self.allow_beta_protocol_version) - kwargs_dict.setdefault('no_compact', self.no_compact) - kwargs_dict.setdefault('application_info', self.application_info) + kwargs_dict.setdefault( + "authenticator", self._auth_provider_callable(endpoint.address) + ) + + kwargs_dict.setdefault("port", self.port) + kwargs_dict.setdefault("compression", self.compression) + kwargs_dict.setdefault("sockopts", self.sockopts) + kwargs_dict.setdefault("ssl_options", self.ssl_options) + kwargs_dict.setdefault("ssl_context", self.ssl_context) + kwargs_dict.setdefault("cql_version", self.cql_version) + kwargs_dict.setdefault("protocol_version", self.protocol_version) + kwargs_dict.setdefault("user_type_map", self._user_types) + kwargs_dict.setdefault( + "allow_beta_protocol_version", self.allow_beta_protocol_version + ) + kwargs_dict.setdefault("no_compact", self.no_compact) + kwargs_dict.setdefault("application_info", self.application_info) return kwargs_dict def protocol_downgrade(self, host_endpoint, previous_version): if self._protocol_version_explicit: - raise DriverException("ProtocolError returned from server while using explicitly set client protocol_version %d" % (previous_version,)) + raise DriverException( + "ProtocolError returned from server while using explicitly set client protocol_version %d" + % (previous_version,) + ) new_version = ProtocolVersion.get_lower_supported(previous_version) if new_version < ProtocolVersion.MIN_SUPPORTED: raise DriverException( - "Cannot downgrade protocol version below minimum supported version: %d" % (ProtocolVersion.MIN_SUPPORTED,)) + "Cannot downgrade protocol version below minimum supported version: %d" + % (ProtocolVersion.MIN_SUPPORTED,) + ) - log.warning("Downgrading core protocol version from %d to %d for %s. " - "To avoid this, it is best practice to explicitly set Cluster(protocol_version) to the version supported by your cluster. " - "http://datastax.github.io/python-driver/api/cassandra/cluster.html#cassandra.cluster.Cluster.protocol_version", self.protocol_version, new_version, host_endpoint) + log.warning( + "Downgrading core protocol version from %d to %d for %s. " + "To avoid this, it is best practice to explicitly set Cluster(protocol_version) to the version supported by your cluster. " + "http://datastax.github.io/python-driver/api/cassandra/cluster.html#cassandra.cluster.Cluster.protocol_version", + self.protocol_version, + new_version, + host_endpoint, + ) self.protocol_version = new_version def _populate_hosts(self): - self.profile_manager.populate( - weakref.proxy(self), self.metadata.all_hosts()) + self.profile_manager.populate(weakref.proxy(self), self.metadata.all_hosts()) self.load_balancing_policy.populate( weakref.proxy(self), self.metadata.all_hosts() ) @@ -1706,8 +1982,11 @@ def connect(self, keyspace=None, wait_for_all_pools=False): raise DriverException("Cluster is already shut down") if not self._is_setup: - log.debug("Connecting to cluster, contact points: %s; protocol version: %s", - self.contact_points, self.protocol_version) + log.debug( + "Connecting to cluster, contact points: %s; protocol version: %s", + self.contact_points, + self.protocol_version, + ) self.connection_class.initialize_reactor() _register_cluster_shutdown(self) @@ -1717,8 +1996,9 @@ def connect(self, keyspace=None, wait_for_all_pools=False): log.debug("Control connection created") except Exception: - log.exception("Control connection failed to connect, " - "shutting down Cluster:") + log.exception( + "Control connection failed to connect, shutting down Cluster:" + ) self.shutdown() raise @@ -1728,7 +2008,7 @@ def connect(self, keyspace=None, wait_for_all_pools=False): self._idle_heartbeat = ConnectionHeartbeat( self.idle_heartbeat_interval, self.get_connection_holders, - timeout=self.idle_heartbeat_timeout + timeout=self.idle_heartbeat_timeout, ) self._is_setup = True @@ -1765,9 +2045,13 @@ def is_shard_aware(self): def shard_aware_stats(self): if self.is_shard_aware(): - return {str(pool.host.endpoint): {'shards_count': pool.host.sharding_info.shards_count, - 'connected': len(pool._connections.keys())} - for pool in self.get_all_pools()} + return { + str(pool.host.endpoint): { + "shards_count": pool.host.sharding_info.shards_count, + "connected": len(pool._connections.keys()), + } + for pool in self.get_all_pools() + } def shutdown(self): """ @@ -1840,12 +2124,16 @@ def _on_up_future_completed(self, host, futures, results, lock, finished_future) try: # all futures have completed at this point for exc in [f for f in results if isinstance(f, Exception)]: - log.error("Unexpected failure while marking node %s up:", host, exc_info=exc) + log.error( + "Unexpected failure while marking node %s up:", host, exc_info=exc + ) self._cleanup_failed_on_up_handling(host) return if not all(results): - log.debug("Connection pool could not be created, not marking node %s up", host) + log.debug( + "Connection pool could not be created, not marking node %s up", host + ) self._cleanup_failed_on_up_handling(host) return @@ -1872,7 +2160,9 @@ def on_up(self, host): log.debug("Waiting to acquire lock for handling up status of node %s", host) with host.lock: if host._currently_handling_node_up: - log.debug("Another thread is already handling up status of node %s", host) + log.debug( + "Another thread is already handling up status of node %s", host + ) return if host.is_up: @@ -1885,11 +2175,15 @@ def on_up(self, host): have_future = False futures = set() try: - log.info("Host %s may be up; will prepare queries and open connection pool", host) + log.info( + "Host %s may be up; will prepare queries and open connection pool", host + ) reconnector = host.get_and_set_reconnection_handler(None) if reconnector: - log.debug("Now that host %s is up, cancelling the reconnection handler", host) + log.debug( + "Now that host %s is up, cancelling the reconnection handler", host + ) reconnector.cancel() if self.profile_manager.distance(host) != HostDistance.IGNORED: @@ -1908,7 +2202,13 @@ def on_up(self, host): log.debug("Attempting to open new connection pools for host %s", host) futures_lock = Lock() futures_results = [] - callback = partial(self._on_up_future_completed, host, futures, futures_results, futures_lock) + callback = partial( + self._on_up_future_completed, + host, + futures, + futures_results, + futures_lock, + ) for session in tuple(self.sessions): future = session.add_or_renew_pool(host, is_host_addition=False) if future is not None: @@ -1946,9 +2246,16 @@ def _start_reconnector(self, host, is_host_addition): conn_factory = self._make_connection_factory(host) reconnector = _HostReconnectionHandler( - host, conn_factory, is_host_addition, self.on_add, self.on_up, - self.scheduler, schedule, host.get_and_set_reconnection_handler, - new_handler=None) + host, + conn_factory, + is_host_addition, + self.on_add, + self.on_up, + self.scheduler, + schedule, + host.get_and_set_reconnection_handler, + new_handler=None, + ) old_reconnector = host.get_and_set_reconnection_handler(reconnector) if old_reconnector: @@ -1982,18 +2289,23 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False): # ignore down signals if we have open pools to the host # this is to avoid closing pools when a control connection host became isolated - if self._discount_down_events and self.profile_manager.distance(host) != HostDistance.IGNORED: + if ( + self._discount_down_events + and self.profile_manager.distance(host) != HostDistance.IGNORED + ): connected = False for session in tuple(self.sessions): pool_states = session.get_pool_state() pool_state = pool_states.get(host) if pool_state: - connected |= pool_state['open_count'] > 0 + connected |= pool_state["open_count"] > 0 if connected: return host.set_down() - if (not was_up and not expect_host_to_be_down) or host.is_currently_reconnecting(): + if ( + not was_up and not expect_host_to_be_down + ) or host.is_currently_reconnecting(): return log.warning("Host %s has been marked down", host) @@ -2014,8 +2326,11 @@ def on_add(self, host, refresh_nodes=True): log.debug("Done preparing queries for new host %r", host) if distance == HostDistance.IGNORED: - log.debug("Not adding connection pool for new host %r because the " - "load balancing policy has marked it as IGNORED", host) + log.debug( + "Not adding connection pool for new host %r because the " + "load balancing policy has marked it as IGNORED", + host, + ) self._finalize_add(host, set_up=False) return @@ -2035,14 +2350,20 @@ def future_completed(future): if futures: return - log.debug('All futures have completed for added host %s', host) + log.debug("All futures have completed for added host %s", host) for exc in [f for f in futures_results if isinstance(f, Exception)]: - log.error("Unexpected failure while adding node %s, will not mark up:", host, exc_info=exc) + log.error( + "Unexpected failure while adding node %s, will not mark up:", + host, + exc_info=exc, + ) return if not all(futures_results): - log.warning("Connection pool could not be created, not marking node %s up", host) + log.warning( + "Connection pool could not be created, not marking node %s up", host + ) return self._finalize_add(host) @@ -2086,13 +2407,23 @@ def on_remove(self, host): if reconnection_handler: reconnection_handler.cancel() - def signal_connection_failure(self, host, connection_exc, is_host_addition, expect_host_to_be_down=False): + def signal_connection_failure( + self, host, connection_exc, is_host_addition, expect_host_to_be_down=False + ): is_down = host.signal_connection_failure(connection_exc) if is_down: self.on_down(host, is_host_addition, expect_host_to_be_down) return is_down - def add_host(self, endpoint, datacenter=None, rack=None, signal=True, refresh_nodes=True, host_id=None): + def add_host( + self, + endpoint, + datacenter=None, + rack=None, + signal=True, + refresh_nodes=True, + host_id=None, + ): """ Called when adding initial contact points and when the control connection subsequently discovers a new node. @@ -2102,8 +2433,18 @@ def add_host(self, endpoint, datacenter=None, rack=None, signal=True, refresh_no """ with self.metadata._hosts_lock: if endpoint in self.metadata._host_id_by_endpoint: - return self.metadata._hosts[self.metadata._host_id_by_endpoint[endpoint]], False - host, new = self.metadata.add_or_return_host(Host(endpoint, self.conviction_policy_factory, datacenter, rack, host_id=host_id)) + return self.metadata._hosts[ + self.metadata._host_id_by_endpoint[endpoint] + ], False + host, new = self.metadata.add_or_return_host( + Host( + endpoint, + self.conviction_policy_factory, + datacenter, + rack, + host_id=host_id, + ) + ) if new and signal: log.info("New Cassandra host %r discovered", host) self.on_add(host, refresh_nodes) @@ -2129,7 +2470,7 @@ def register_listener(self, listener): self._listeners.add(listener) def unregister_listener(self, listener): - """ Removes a registered listener. """ + """Removes a registered listener.""" with self._listener_lock: self._listeners.remove(listener) @@ -2151,9 +2492,13 @@ def _ensure_core_connections(self): def _validate_refresh_schema(keyspace, table, usertype, function, aggregate): if any((table, usertype, function, aggregate)): if not keyspace: - raise ValueError("keyspace is required to refresh specific sub-entity {table, usertype, function, aggregate}") + raise ValueError( + "keyspace is required to refresh specific sub-entity {table, usertype, function, aggregate}" + ) if sum(1 for e in (table, usertype, function) if e) > 1: - raise ValueError("{table, usertype, function, aggregate} are mutually exclusive") + raise ValueError( + "{table, usertype, function, aggregate} are mutually exclusive" + ) @staticmethod def _target_type_from_refresh_args(keyspace, table, usertype, function, aggregate): @@ -2190,8 +2535,12 @@ def refresh_schema_metadata(self, max_schema_agreement_wait=None): An Exception is raised if schema refresh fails for any reason. """ - if not self.control_connection.refresh_schema(schema_agreement_wait=max_schema_agreement_wait, force=True): - raise DriverException("Schema metadata was not refreshed. See log for details.") + if not self.control_connection.refresh_schema( + schema_agreement_wait=max_schema_agreement_wait, force=True + ): + raise DriverException( + "Schema metadata was not refreshed. See log for details." + ) def refresh_keyspace_metadata(self, keyspace, max_schema_agreement_wait=None): """ @@ -2200,9 +2549,15 @@ def refresh_keyspace_metadata(self, keyspace, max_schema_agreement_wait=None): See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ - if not self.control_connection.refresh_schema(target_type=SchemaTargetType.KEYSPACE, keyspace=keyspace, - schema_agreement_wait=max_schema_agreement_wait, force=True): - raise DriverException("Keyspace metadata was not refreshed. See log for details.") + if not self.control_connection.refresh_schema( + target_type=SchemaTargetType.KEYSPACE, + keyspace=keyspace, + schema_agreement_wait=max_schema_agreement_wait, + force=True, + ): + raise DriverException( + "Keyspace metadata was not refreshed. See log for details." + ) def refresh_table_metadata(self, keyspace, table, max_schema_agreement_wait=None): """ @@ -2211,31 +2566,58 @@ def refresh_table_metadata(self, keyspace, table, max_schema_agreement_wait=None See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ - if not self.control_connection.refresh_schema(target_type=SchemaTargetType.TABLE, keyspace=keyspace, table=table, - schema_agreement_wait=max_schema_agreement_wait, force=True): - raise DriverException("Table metadata was not refreshed. See log for details.") + if not self.control_connection.refresh_schema( + target_type=SchemaTargetType.TABLE, + keyspace=keyspace, + table=table, + schema_agreement_wait=max_schema_agreement_wait, + force=True, + ): + raise DriverException( + "Table metadata was not refreshed. See log for details." + ) - def refresh_materialized_view_metadata(self, keyspace, view, max_schema_agreement_wait=None): + def refresh_materialized_view_metadata( + self, keyspace, view, max_schema_agreement_wait=None + ): """ Synchronously refresh materialized view metadata. See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ - if not self.control_connection.refresh_schema(target_type=SchemaTargetType.TABLE, keyspace=keyspace, table=view, - schema_agreement_wait=max_schema_agreement_wait, force=True): - raise DriverException("View metadata was not refreshed. See log for details.") + if not self.control_connection.refresh_schema( + target_type=SchemaTargetType.TABLE, + keyspace=keyspace, + table=view, + schema_agreement_wait=max_schema_agreement_wait, + force=True, + ): + raise DriverException( + "View metadata was not refreshed. See log for details." + ) - def refresh_user_type_metadata(self, keyspace, user_type, max_schema_agreement_wait=None): + def refresh_user_type_metadata( + self, keyspace, user_type, max_schema_agreement_wait=None + ): """ Synchronously refresh user defined type metadata. See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ - if not self.control_connection.refresh_schema(target_type=SchemaTargetType.TYPE, keyspace=keyspace, type=user_type, - schema_agreement_wait=max_schema_agreement_wait, force=True): - raise DriverException("User Type metadata was not refreshed. See log for details.") + if not self.control_connection.refresh_schema( + target_type=SchemaTargetType.TYPE, + keyspace=keyspace, + type=user_type, + schema_agreement_wait=max_schema_agreement_wait, + force=True, + ): + raise DriverException( + "User Type metadata was not refreshed. See log for details." + ) - def refresh_user_function_metadata(self, keyspace, function, max_schema_agreement_wait=None): + def refresh_user_function_metadata( + self, keyspace, function, max_schema_agreement_wait=None + ): """ Synchronously refresh user defined function metadata. @@ -2243,11 +2625,20 @@ def refresh_user_function_metadata(self, keyspace, function, max_schema_agreemen See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ - if not self.control_connection.refresh_schema(target_type=SchemaTargetType.FUNCTION, keyspace=keyspace, function=function, - schema_agreement_wait=max_schema_agreement_wait, force=True): - raise DriverException("User Function metadata was not refreshed. See log for details.") + if not self.control_connection.refresh_schema( + target_type=SchemaTargetType.FUNCTION, + keyspace=keyspace, + function=function, + schema_agreement_wait=max_schema_agreement_wait, + force=True, + ): + raise DriverException( + "User Function metadata was not refreshed. See log for details." + ) - def refresh_user_aggregate_metadata(self, keyspace, aggregate, max_schema_agreement_wait=None): + def refresh_user_aggregate_metadata( + self, keyspace, aggregate, max_schema_agreement_wait=None + ): """ Synchronously refresh user defined aggregate metadata. @@ -2255,9 +2646,16 @@ def refresh_user_aggregate_metadata(self, keyspace, aggregate, max_schema_agreem See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ - if not self.control_connection.refresh_schema(target_type=SchemaTargetType.AGGREGATE, keyspace=keyspace, aggregate=aggregate, - schema_agreement_wait=max_schema_agreement_wait, force=True): - raise DriverException("User Aggregate metadata was not refreshed. See log for details.") + if not self.control_connection.refresh_schema( + target_type=SchemaTargetType.AGGREGATE, + keyspace=keyspace, + aggregate=aggregate, + schema_agreement_wait=max_schema_agreement_wait, + force=True, + ): + raise DriverException( + "User Aggregate metadata was not refreshed. See log for details." + ) def refresh_nodes(self, force_token_rebuild=False): """ @@ -2267,7 +2665,9 @@ def refresh_nodes(self, force_token_rebuild=False): An Exception is raised if node refresh fails for any reason. """ - if not self.control_connection.refresh_node_list_and_token_map(force_token_rebuild): + if not self.control_connection.refresh_node_list_and_token_map( + force_token_rebuild + ): raise DriverException("Node list was not refreshed. See log for details.") def set_meta_refresh_enabled(self, enabled): @@ -2282,23 +2682,35 @@ def set_meta_refresh_enabled(self, enabled): Meta refresh must be enabled for the driver to become aware of any cluster topology changes or schema updates. """ - warn("Cluster.set_meta_refresh_enabled is deprecated and will be removed in 4.0. Set " - "Cluster.schema_metadata_enabled and Cluster.token_metadata_enabled instead.", DeprecationWarning) + warn( + "Cluster.set_meta_refresh_enabled is deprecated and will be removed in 4.0. Set " + "Cluster.schema_metadata_enabled and Cluster.token_metadata_enabled instead.", + DeprecationWarning, + ) self.schema_metadata_enabled = enabled self.token_metadata_enabled = enabled @classmethod def _send_chunks(cls, connection, host, chunks, set_keyspace=False): for ks_chunk in chunks: - messages = [PrepareMessage(query=s.query_string, - keyspace=s.keyspace if set_keyspace else None) - for s in ks_chunk] + messages = [ + PrepareMessage( + query=s.query_string, keyspace=s.keyspace if set_keyspace else None + ) + for s in ks_chunk + ] # TODO: make this timeout configurable somehow? - responses = connection.wait_for_responses(*messages, timeout=5.0, fail_on_error=False) + responses = connection.wait_for_responses( + *messages, timeout=5.0, fail_on_error=False + ) for success, response in responses: if not success: - log.debug("Got unexpected response when preparing " - "statement on host %s: %r", host, response) + log.debug( + "Got unexpected response when preparing " + "statement on host %s: %r", + host, + response, + ) def _prepare_all_queries(self, host): if not self._prepared_statements or not self.reprepare_on_up: @@ -2313,10 +2725,12 @@ def _prepare_all_queries(self, host): # V5 protocol and higher, no need to set the keyspace chunks = [] for i in range(0, len(statements), 10): - chunks.append(statements[i:i + 10]) + chunks.append(statements[i : i + 10]) self._send_chunks(connection, host, chunks, True) else: - for keyspace, ks_statements in groupby(statements, lambda s: s.keyspace): + for keyspace, ks_statements in groupby( + statements, lambda s: s.keyspace + ): if keyspace is not None: connection.set_keyspace_blocking(keyspace) @@ -2324,14 +2738,22 @@ def _prepare_all_queries(self, host): ks_statements = list(ks_statements) chunks = [] for i in range(0, len(ks_statements), 10): - chunks.append(ks_statements[i:i + 10]) + chunks.append(ks_statements[i : i + 10]) self._send_chunks(connection, host, chunks) - log.debug("Done preparing all known prepared statements against host %s", host) + log.debug( + "Done preparing all known prepared statements against host %s", host + ) except OperationTimedOut as timeout: - log.warning("Timed out trying to prepare all statements on host %s: %s", host, timeout) + log.warning( + "Timed out trying to prepare all statements on host %s: %s", + host, + timeout, + ) except (ConnectionException, socket.error) as exc: - log.warning("Error trying to prepare all statements on host %s: %r", host, exc) + log.warning( + "Error trying to prepare all statements on host %s: %r", host, exc + ) except Exception: log.exception("Error trying to prepare all statements on host %s", host) finally: @@ -2342,6 +2764,7 @@ def add_prepared(self, query_id, prepared_statement): with self._prepared_statement_lock: self._prepared_statements[query_id] = prepared_statement + class Session(object): """ A collection of connection pools for each host in the cluster. @@ -2368,6 +2791,7 @@ class Session(object): _monitor_reporter = None _row_factory = staticmethod(named_tuple_factory) + @property def row_factory(self): """ @@ -2385,7 +2809,7 @@ def row_factory(self): @row_factory.setter def row_factory(self, rf): - self._validate_set_legacy_config('row_factory', rf) + self._validate_set_legacy_config("row_factory", rf) _default_timeout = 10.0 @@ -2407,7 +2831,7 @@ def default_timeout(self): @default_timeout.setter def default_timeout(self, timeout): - self._validate_set_legacy_config('default_timeout', timeout) + self._validate_set_legacy_config("default_timeout", timeout) _default_consistency_level = ConsistencyLevel.LOCAL_ONE @@ -2432,10 +2856,12 @@ def default_consistency_level(self, cl): """ *Deprecated:* use execution profiles instead """ - warn("Setting the consistency level at the session level will be removed in 4.0. Consider using " - "execution profiles and setting the desired consistency level to the EXEC_PROFILE_DEFAULT profile." - , DeprecationWarning) - self._validate_set_legacy_config('default_consistency_level', cl) + warn( + "Setting the consistency level at the session level will be removed in 4.0. Consider using " + "execution profiles and setting the desired consistency level to the EXEC_PROFILE_DEFAULT profile.", + DeprecationWarning, + ) + self._validate_set_legacy_config("default_consistency_level", cl) _default_serial_consistency_level = None @@ -2452,13 +2878,14 @@ def default_serial_consistency_level(self): @default_serial_consistency_level.setter def default_serial_consistency_level(self, cl): - if (cl is not None and - not ConsistencyLevel.is_serial(cl)): - raise ValueError("default_serial_consistency_level must be either " - "ConsistencyLevel.SERIAL " - "or ConsistencyLevel.LOCAL_SERIAL.") + if cl is not None and not ConsistencyLevel.is_serial(cl): + raise ValueError( + "default_serial_consistency_level must be either " + "ConsistencyLevel.SERIAL " + "or ConsistencyLevel.LOCAL_SERIAL." + ) - self._validate_set_legacy_config('default_serial_consistency_level', cl) + self._validate_set_legacy_config("default_serial_consistency_level", cl) max_trace_wait = 2.0 """ @@ -2584,7 +3011,9 @@ def __init__(self, cluster, hosts, keyspace=None): if future: self._initial_connect_futures.add(future) - futures = wait_futures(self._initial_connect_futures, return_when=FIRST_COMPLETED) + futures = wait_futures( + self._initial_connect_futures, return_when=FIRST_COMPLETED + ) while futures.not_done and not any(f.result() for f in futures.done): futures = wait_futures(futures.not_done, return_when=FIRST_COMPLETED) @@ -2601,15 +3030,19 @@ def __init__(self, cluster, hosts, keyspace=None): self.client_protocol_handler = type( str(self.session_id) + "-ProtocolHandler", (ProtocolHandler,), - {"column_encryption_policy": self.cluster.column_encryption_policy}) + {"column_encryption_policy": self.cluster.column_encryption_policy}, + ) except AttributeError: log.info("Unable to set column encryption policy for session") raise Exception( - "column_encryption_policy is temporary disabled, until https://github.com/scylladb/python-driver/issues/365 is sorted out") + "column_encryption_policy is temporary disabled, until https://github.com/scylladb/python-driver/issues/365 is sorted out" + ) if self.cluster.monitor_reporting_enabled: cc_host = self.cluster.get_control_connection_host() - valid_insights_version = (cc_host and version_supports_insights(cc_host.dse_version)) + valid_insights_version = cc_host and version_supports_insights( + cc_host.dse_version + ) if valid_insights_version: self._monitor_reporter = MonitorReporter( interval_sec=self.cluster.monitor_reporting_interval, @@ -2617,16 +3050,32 @@ def __init__(self, cluster, hosts, keyspace=None): ) else: if cc_host: - log.debug('Not starting MonitorReporter thread for Insights; ' - 'not supported by server version {v} on ' - 'ControlConnection host {c}'.format(v=cc_host.release_version, c=cc_host)) + log.debug( + "Not starting MonitorReporter thread for Insights; " + "not supported by server version {v} on " + "ControlConnection host {c}".format( + v=cc_host.release_version, c=cc_host + ) + ) - log.debug('Started Session with client_id {} and session_id {}'.format(self.cluster.client_id, - self.session_id)) + log.debug( + "Started Session with client_id {} and session_id {}".format( + self.cluster.client_id, self.session_id + ) + ) - def execute(self, query, parameters=None, timeout=_NOT_SET, trace=False, - custom_payload=None, execution_profile=EXEC_PROFILE_DEFAULT, - paging_state=None, host=None, execute_as=None): + def execute( + self, + query, + parameters=None, + timeout=_NOT_SET, + trace=False, + custom_payload=None, + execution_profile=EXEC_PROFILE_DEFAULT, + paging_state=None, + host=None, + execute_as=None, + ): """ Execute the given query and synchronously wait for the response. @@ -2667,11 +3116,30 @@ def execute(self, query, parameters=None, timeout=_NOT_SET, trace=False, on a DSE cluster. """ - return self.execute_async(query, parameters, trace, custom_payload, timeout, execution_profile, paging_state, host, execute_as).result() - - def execute_async(self, query, parameters=None, trace=False, custom_payload=None, - timeout=_NOT_SET, execution_profile=EXEC_PROFILE_DEFAULT, - paging_state=None, host=None, execute_as=None): + return self.execute_async( + query, + parameters, + trace, + custom_payload, + timeout, + execution_profile, + paging_state, + host, + execute_as, + ).result() + + def execute_async( + self, + query, + parameters=None, + trace=False, + custom_payload=None, + timeout=_NOT_SET, + execution_profile=EXEC_PROFILE_DEFAULT, + paging_state=None, + host=None, + execute_as=None, + ): """ Execute the given query and return a :class:`~.ResponseFuture` object which callbacks may be attached to for asynchronous response @@ -2711,14 +3179,28 @@ def execute_async(self, query, parameters=None, trace=False, custom_payload=None custom_payload[_proxy_execute_key] = execute_as.encode() future = self._create_response_future( - query, parameters, trace, custom_payload, timeout, - execution_profile, paging_state, host) + query, + parameters, + trace, + custom_payload, + timeout, + execution_profile, + paging_state, + host, + ) future._protocol_handler = self.client_protocol_handler self._on_request(future) future.send_request() return future - def execute_graph(self, query, parameters=None, trace=False, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT, execute_as=None): + def execute_graph( + self, + query, + parameters=None, + trace=False, + execution_profile=EXEC_PROFILE_GRAPH_DEFAULT, + execute_as=None, + ): """ Executes a Gremlin query string or GraphStatement synchronously, and returns a ResultSet from this execution. @@ -2730,18 +3212,31 @@ def execute_graph(self, query, parameters=None, trace=False, execution_profile=E `execute_as` the user that will be used on the server to execute the request. """ - return self.execute_graph_async(query, parameters, trace, execution_profile, execute_as).result() + return self.execute_graph_async( + query, parameters, trace, execution_profile, execute_as + ).result() - def execute_graph_async(self, query, parameters=None, trace=False, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT, execute_as=None): + def execute_graph_async( + self, + query, + parameters=None, + trace=False, + execution_profile=EXEC_PROFILE_GRAPH_DEFAULT, + execute_as=None, + ): """ Execute the graph query and return a :class:`ResponseFuture` object which callbacks may be attached to for asynchronous response delivery. You may also call ``ResponseFuture.result()`` to synchronously block for results at any time. """ if self.cluster._config_mode is _ConfigMode.LEGACY: - raise ValueError(("Cannot execute graph queries using Cluster legacy parameters. " - "Consider using Execution profiles: " - "https://docs.datastax.com/en/developer/python-driver/latest/execution_profiles/#execution-profiles")) + raise ValueError( + ( + "Cannot execute graph queries using Cluster legacy parameters. " + "Consider using Execution profiles: " + "https://docs.datastax.com/en/developer/python-driver/latest/execution_profiles/#execution-profiles" + ) + ) if not isinstance(query, GraphStatement): query = SimpleGraphStatement(query) @@ -2749,16 +3244,19 @@ def execute_graph_async(self, query, parameters=None, trace=False, execution_pro # Clone and look up instance here so we can resolve and apply the extended attributes execution_profile = self.execution_profile_clone_update(execution_profile) - if not hasattr(execution_profile, 'graph_options'): + if not hasattr(execution_profile, "graph_options"): raise ValueError( - "Execution profile for graph queries must derive from GraphExecutionProfile, and provide graph_options") + "Execution profile for graph queries must derive from GraphExecutionProfile, and provide graph_options" + ) self._resolve_execution_profile_options(execution_profile) # make sure the graphson context row factory is binded to this cluster try: if issubclass(execution_profile.row_factory, _GraphSONContextRowFactory): - execution_profile.row_factory = execution_profile.row_factory(self.cluster) + execution_profile.row_factory = execution_profile.row_factory( + self.cluster + ) except TypeError: # issubclass might fail if arg1 is an instance pass @@ -2768,21 +3266,32 @@ def execute_graph_async(self, query, parameters=None, trace=False, execution_pro graph_parameters = None if parameters: - graph_parameters = self._transform_params(parameters, graph_options=execution_profile.graph_options) + graph_parameters = self._transform_params( + parameters, graph_options=execution_profile.graph_options + ) custom_payload = execution_profile.graph_options.get_options_map() if execute_as: custom_payload[_proxy_execute_key] = execute_as.encode() - custom_payload[_request_timeout_key] = int64_pack(int(execution_profile.request_timeout * 1000)) + custom_payload[_request_timeout_key] = int64_pack( + int(execution_profile.request_timeout * 1000) + ) - future = self._create_response_future(query, parameters=None, trace=trace, custom_payload=custom_payload, - timeout=_NOT_SET, execution_profile=execution_profile) + future = self._create_response_future( + query, + parameters=None, + trace=trace, + custom_payload=custom_payload, + timeout=_NOT_SET, + execution_profile=execution_profile, + ) future.message.query_params = graph_parameters future._protocol_handler = self.client_protocol_handler - if execution_profile.graph_options.is_analytics_source and \ - isinstance(execution_profile.load_balancing_policy, DefaultLoadBalancingPolicy): + if execution_profile.graph_options.is_analytics_source and isinstance( + execution_profile.load_balancing_policy, DefaultLoadBalancingPolicy + ): self._target_analytics_master(future) else: future.send_request() @@ -2808,8 +3317,10 @@ def _resolve_execution_profile_options(self, execution_profile): - Default to graph_object_row_factory. - If `graph_options.graph_name` is specified and is a Core graph, set graph_graphson3_row_factory. """ - if execution_profile.graph_options.graph_protocol is not None and \ - execution_profile.row_factory is not None: + if ( + execution_profile.graph_options.graph_protocol is not None + and execution_profile.row_factory is not None + ): return graph_options = execution_profile.graph_options @@ -2817,10 +3328,10 @@ def _resolve_execution_profile_options(self, execution_profile): is_core_graph = False if graph_options.graph_name: # graph_options.graph_name is bytes ... - name = graph_options.graph_name.decode('utf-8') + name = graph_options.graph_name.decode("utf-8") if name in self.cluster.metadata.keyspaces: ks_metadata = self.cluster.metadata.keyspaces[name] - if ks_metadata.graph_engine == 'Core': + if ks_metadata.graph_engine == "Core": is_core_graph = True if is_core_graph: @@ -2843,7 +3354,9 @@ def _resolve_execution_profile_options(self, execution_profile): def _transform_params(self, parameters, graph_options): if not isinstance(parameters, dict): - raise ValueError('The parameters must be a dictionary. Unnamed parameters are not allowed.') + raise ValueError( + "The parameters must be a dictionary. Unnamed parameters are not allowed." + ) # Serialize python types to graphson serializer = GraphSON1Serializer @@ -2852,45 +3365,68 @@ def _transform_params(self, parameters, graph_options): elif graph_options.graph_protocol == GraphProtocol.GRAPHSON_3_0: # only required for core graphs context = { - 'cluster': self.cluster, - 'graph_name': graph_options.graph_name.decode('utf-8') if graph_options.graph_name else None + "cluster": self.cluster, + "graph_name": graph_options.graph_name.decode("utf-8") + if graph_options.graph_name + else None, } serializer = GraphSON3Serializer(context) serialized_parameters = serializer.serialize(parameters) - return [json.dumps(serialized_parameters).encode('utf-8')] + return [json.dumps(serialized_parameters).encode("utf-8")] def _target_analytics_master(self, future): future._start_timer() - master_query_future = self._create_response_future("CALL DseClientTool.getAnalyticsGraphServer()", - parameters=None, trace=False, - custom_payload=None, timeout=future.timeout) + master_query_future = self._create_response_future( + "CALL DseClientTool.getAnalyticsGraphServer()", + parameters=None, + trace=False, + custom_payload=None, + timeout=future.timeout, + ) master_query_future.row_factory = tuple_factory master_query_future.send_request() cb = self._on_analytics_master_result args = (master_query_future, future) - master_query_future.add_callbacks(callback=cb, callback_args=args, errback=cb, errback_args=args) + master_query_future.add_callbacks( + callback=cb, callback_args=args, errback=cb, errback_args=args + ) def _on_analytics_master_result(self, response, master_future, query_future): try: row = master_future.result()[0] - addr = row[0]['location'] - delimiter_index = addr.rfind(':') # assumes : - not robust, but that's what is being provided + addr = row[0]["location"] + delimiter_index = addr.rfind( + ":" + ) # assumes : - not robust, but that's what is being provided if delimiter_index > 0: addr = addr[:delimiter_index] targeted_query = HostTargetingStatement(query_future.query, addr) - query_future.query_plan = query_future._load_balancer.make_query_plan(self.keyspace, targeted_query) + query_future.query_plan = query_future._load_balancer.make_query_plan( + self.keyspace, targeted_query + ) except Exception: - log.debug("Failed querying analytics master (request might not be routed optimally). " - "Make sure the session is connecting to a graph analytics datacenter.", exc_info=True) + log.debug( + "Failed querying analytics master (request might not be routed optimally). " + "Make sure the session is connecting to a graph analytics datacenter.", + exc_info=True, + ) self.submit(query_future.send_request) - def _create_response_future(self, query, parameters, trace, custom_payload, - timeout, execution_profile=EXEC_PROFILE_DEFAULT, - paging_state=None, host=None): - """ Returns the ResponseFuture before calling send_request() on it """ + def _create_response_future( + self, + query, + parameters, + trace, + custom_payload, + timeout, + execution_profile=EXEC_PROFILE_DEFAULT, + paging_state=None, + host=None, + ): + """Returns the ResponseFuture before calling send_request() on it""" prepared_statement = None @@ -2901,13 +3437,23 @@ def _create_response_future(self, query, parameters, trace, custom_payload, if self.cluster._config_mode == _ConfigMode.LEGACY: if execution_profile is not EXEC_PROFILE_DEFAULT: - raise ValueError("Cannot specify execution_profile while using legacy parameters.") + raise ValueError( + "Cannot specify execution_profile while using legacy parameters." + ) if timeout is _NOT_SET: timeout = self.default_timeout - cl = query.consistency_level if query.consistency_level is not None else self.default_consistency_level - serial_cl = query.serial_consistency_level if query.serial_consistency_level is not None else self.default_serial_consistency_level + cl = ( + query.consistency_level + if query.consistency_level is not None + else self.default_consistency_level + ) + serial_cl = ( + query.serial_consistency_level + if query.serial_consistency_level is not None + else self.default_serial_consistency_level + ) retry_policy = query.retry_policy or self.cluster.default_retry_policy row_factory = self.row_factory @@ -2920,8 +3466,16 @@ def _create_response_future(self, query, parameters, trace, custom_payload, if timeout is _NOT_SET: timeout = execution_profile.request_timeout - cl = query.consistency_level if query.consistency_level is not None else execution_profile.consistency_level - serial_cl = query.serial_consistency_level if query.serial_consistency_level is not None else execution_profile.serial_consistency_level + cl = ( + query.consistency_level + if query.consistency_level is not None + else execution_profile.consistency_level + ) + serial_cl = ( + query.serial_consistency_level + if query.serial_consistency_level is not None + else execution_profile.serial_consistency_level + ) continuous_paging_options = execution_profile.continuous_paging_options retry_policy = query.retry_policy or execution_profile.retry_policy @@ -2941,48 +3495,94 @@ def _create_response_future(self, query, parameters, trace, custom_payload, if isinstance(query, SimpleStatement): query_string = query.query_string - statement_keyspace = query.keyspace if ProtocolVersion.uses_keyspace_flag(self._protocol_version) else None + statement_keyspace = ( + query.keyspace + if ProtocolVersion.uses_keyspace_flag(self._protocol_version) + else None + ) if parameters: query_string = bind_params(query_string, parameters, self.encoder) message = QueryMessage( - query_string, cl, serial_cl, - fetch_size, paging_state, timestamp, - continuous_paging_options, statement_keyspace) + query_string, + cl, + serial_cl, + fetch_size, + paging_state, + timestamp, + continuous_paging_options, + statement_keyspace, + ) elif isinstance(query, BoundStatement): prepared_statement = query.prepared_statement message = ExecuteMessage( - prepared_statement.query_id, query.values, cl, - serial_cl, fetch_size, paging_state, timestamp, + prepared_statement.query_id, + query.values, + cl, + serial_cl, + fetch_size, + paging_state, + timestamp, skip_meta=bool(prepared_statement.result_metadata), continuous_paging_options=continuous_paging_options, - result_metadata_id=prepared_statement.result_metadata_id) + result_metadata_id=prepared_statement.result_metadata_id, + ) elif isinstance(query, BatchStatement): if self._protocol_version < 2: raise UnsupportedOperation( "BatchStatement execution is only supported with protocol version " "2 or higher (supported in Cassandra 2.0 and higher). Consider " - "setting Cluster.protocol_version to 2 to support this operation.") - statement_keyspace = query.keyspace if ProtocolVersion.uses_keyspace_flag(self._protocol_version) else None + "setting Cluster.protocol_version to 2 to support this operation." + ) + statement_keyspace = ( + query.keyspace + if ProtocolVersion.uses_keyspace_flag(self._protocol_version) + else None + ) message = BatchMessage( - query.batch_type, query._statements_and_parameters, cl, - serial_cl, timestamp, statement_keyspace) + query.batch_type, + query._statements_and_parameters, + cl, + serial_cl, + timestamp, + statement_keyspace, + ) elif isinstance(query, GraphStatement): # the statement_keyspace is not aplicable to GraphStatement - message = QueryMessage(query.query, cl, serial_cl, fetch_size, - paging_state, timestamp, - continuous_paging_options) + message = QueryMessage( + query.query, + cl, + serial_cl, + fetch_size, + paging_state, + timestamp, + continuous_paging_options, + ) message.tracing = trace message.update_custom_payload(query.custom_payload) message.update_custom_payload(custom_payload) message.allow_beta_protocol_version = self.cluster.allow_beta_protocol_version - spec_exec_plan = spec_exec_policy.new_plan(query.keyspace or self.keyspace, query) if query.is_idempotent and spec_exec_policy else None + spec_exec_plan = ( + spec_exec_policy.new_plan(query.keyspace or self.keyspace, query) + if query.is_idempotent and spec_exec_policy + else None + ) return ResponseFuture( - self, message, query, timeout, metrics=self._metrics, - prepared_statement=prepared_statement, retry_policy=retry_policy, row_factory=row_factory, - load_balancer=load_balancing_policy, start_time=start_time, speculative_execution_plan=spec_exec_plan, - continuous_paging_state=None, host=host) + self, + message, + query, + timeout, + metrics=self._metrics, + prepared_statement=prepared_statement, + retry_policy=retry_policy, + row_factory=row_factory, + load_balancer=load_balancing_policy, + start_time=start_time, + speculative_execution_plan=spec_exec_plan, + continuous_paging_state=None, + host=host, + ) def get_execution_profile(self, name): """ @@ -2995,11 +3595,15 @@ def get_execution_profile(self, name): return profiles[name] except KeyError: eps = [_execution_profile_to_string(ep) for ep in profiles.keys()] - raise ValueError("Invalid execution_profile: %s; valid profiles are: %s." % ( - _execution_profile_to_string(name), ', '.join(eps))) + raise ValueError( + "Invalid execution_profile: %s; valid profiles are: %s." + % (_execution_profile_to_string(name), ", ".join(eps)) + ) def _maybe_get_execution_profile(self, ep): - return ep if isinstance(ep, ExecutionProfile) else self.get_execution_profile(ep) + return ( + ep if isinstance(ep, ExecutionProfile) else self.get_execution_profile(ep) + ) def execution_profile_clone_update(self, ep, **kwargs): """ @@ -3099,8 +3703,18 @@ def prepare(self, query, custom_payload=None, keyspace=None): prepared_keyspace = keyspace if keyspace else None prepared_statement = PreparedStatement.from_message( - response.query_id, response.bind_metadata, response.pk_indexes, self.cluster.metadata, query, prepared_keyspace, - self._protocol_version, response.column_metadata, response.result_metadata_id, response.is_lwt, self.cluster.column_encryption_policy) + response.query_id, + response.bind_metadata, + response.pk_indexes, + self.cluster.metadata, + query, + prepared_keyspace, + self._protocol_version, + response.column_metadata, + response.result_metadata_id, + response.is_lwt, + self.cluster.column_encryption_policy, + ) prepared_statement.custom_payload = future.custom_payload self.cluster.add_prepared(response.query_id, prepared_statement) @@ -3108,7 +3722,9 @@ def prepare(self, query, custom_payload=None, keyspace=None): if self.cluster.prepare_on_all_hosts: host = future._current_host try: - self.prepare_on_all_hosts(prepared_statement.query_string, host, prepared_keyspace) + self.prepare_on_all_hosts( + prepared_statement.query_string, host, prepared_keyspace + ) except Exception: log.exception("Error preparing query on all hosts:") @@ -3122,8 +3738,12 @@ def prepare_on_all_hosts(self, query, excluded_host, keyspace=None): futures = [] for host in tuple(self._pools.keys()): if host != excluded_host and host.is_up: - future = ResponseFuture(self, PrepareMessage(query=query, keyspace=keyspace), - None, self.default_timeout) + future = ResponseFuture( + self, + PrepareMessage(query=query, keyspace=keyspace), + None, + self.default_timeout, + ) # we don't care about errors preparing against specific hosts, # since we can always prepare them as needed when the prepared @@ -3136,8 +3756,11 @@ def prepare_on_all_hosts(self, query, excluded_host, keyspace=None): if request_id is None: # the error has already been logged by ResponsFuture - log.debug("Failed to prepare query for host %s: %r", - host, future._errors.get(host)) + log.debug( + "Failed to prepare query for host %s: %r", + host, + future._errors.get(host), + ) continue futures.append((host, future)) @@ -3197,18 +3820,22 @@ def add_or_renew_pool(self, host, is_host_addition): def run_add_or_renew_pool(): try: - new_pool = HostConnection(host, distance, self) + new_pool = HostConnection(host, distance, self) except AuthenticationFailed as auth_exc: conn_exc = ConnectionException(str(auth_exc), endpoint=host) self.cluster.signal_connection_failure(host, conn_exc, is_host_addition) return False except Exception as conn_exc: - log.warning("Failed to create connection pool for new host %s:", - host, exc_info=conn_exc) + log.warning( + "Failed to create connection pool for new host %s:", + host, + exc_info=conn_exc, + ) # the host itself will still be marked down, so we need to pass # a special flag to make sure the reconnector is created self.cluster.signal_connection_failure( - host, conn_exc, is_host_addition, expect_host_to_be_down=True) + host, conn_exc, is_host_addition, expect_host_to_be_down=True + ) return False previous = self._pools.get(host) @@ -3225,7 +3852,10 @@ def callback(pool, errors): new_pool._set_keyspace_for_all_conns(self.keyspace, callback) set_keyspace_event.wait(self.cluster.connect_timeout) if not set_keyspace_event.is_set() or errors_returned: - log.warning("Failed setting keyspace for pool after keyspace changed during connect: %s", errors_returned) + log.warning( + "Failed setting keyspace for pool after keyspace changed during connect: %s", + errors_returned, + ) self.cluster.on_down(host, is_host_addition) new_pool.shutdown() self._lock.acquire() @@ -3292,7 +3922,7 @@ def on_down(self, host): future.add_done_callback(lambda f: self.update_created_pools()) def on_remove(self, host): - """ Internal """ + """Internal""" self.on_down(host) def set_keyspace(self, keyspace): @@ -3300,7 +3930,7 @@ def set_keyspace(self, keyspace): Set the default keyspace for all queries made through this Session. This operation blocks until complete. """ - self.execute('USE %s' % (protect_name(keyspace),)) + self.execute("USE %s" % (protect_name(keyspace),)) def _set_keyspace_for_all_pools(self, keyspace, callback): """ @@ -3339,39 +3969,51 @@ def user_type_registered(self, keyspace, user_type, klass): ks_meta = self.cluster.metadata.keyspaces[keyspace] except KeyError: raise UserTypeDoesNotExist( - 'Keyspace %s does not exist or has not been discovered by the driver' % (keyspace,)) + "Keyspace %s does not exist or has not been discovered by the driver" + % (keyspace,) + ) try: type_meta = ks_meta.user_types[user_type] except KeyError: raise UserTypeDoesNotExist( - 'User type %s does not exist in keyspace %s' % (user_type, keyspace)) + "User type %s does not exist in keyspace %s" % (user_type, keyspace) + ) field_names = type_meta.field_names def encode(val): - return '{ %s }' % ' , '.join('%s : %s' % ( - field_name, - self.encoder.cql_encode_all_types(getattr(val, field_name, None)) - ) for field_name in field_names) + return "{ %s }" % " , ".join( + "%s : %s" + % ( + field_name, + self.encoder.cql_encode_all_types(getattr(val, field_name, None)), + ) + for field_name in field_names + ) self.encoder.mapping[klass] = encode def submit(self, fn, *args, **kwargs): - """ Internal """ + """Internal""" if not self.is_shutdown: return self.cluster.executor.submit(fn, *args, **kwargs) def get_pool_state(self): - return dict((host, pool.get_state()) for host, pool in tuple(self._pools.items())) + return dict( + (host, pool.get_state()) for host, pool in tuple(self._pools.items()) + ) def get_pools(self): return self._pools.values() def _validate_set_legacy_config(self, attr_name, value): if self.cluster._config_mode == _ConfigMode.PROFILES: - raise ValueError("Cannot set Session.%s while using Configuration Profiles. Set this in a profile instead." % (attr_name,)) - setattr(self, '_' + attr_name, value) + raise ValueError( + "Cannot set Session.%s while using Configuration Profiles. Set this in a profile instead." + % (attr_name,) + ) + setattr(self, "_" + attr_name, value) self.cluster._config_mode = _ConfigMode.LEGACY @@ -3381,6 +4023,7 @@ class UserTypeDoesNotExist(Exception): .. versionadded:: 2.1.0 """ + pass @@ -3440,9 +4083,13 @@ class ControlConnection(object): _SELECT_LOCAL = "SELECT broadcast_address, cluster_name, data_center, host_id, listen_address, partitioner, rack, release_version, rpc_address, schema_version, tokens FROM system.local WHERE key='local'" _SELECT_LOCAL_NO_TOKENS = "SELECT host_id, cluster_name, data_center, rack, partitioner, release_version, schema_version, rpc_address FROM system.local WHERE key='local'" # Used only when token_metadata_enabled is set to False - _SELECT_LOCAL_NO_TOKENS_RPC_ADDRESS = "SELECT rpc_address FROM system.local WHERE key='local'" + _SELECT_LOCAL_NO_TOKENS_RPC_ADDRESS = ( + "SELECT rpc_address FROM system.local WHERE key='local'" + ) - _SELECT_SCHEMA_PEERS_TEMPLATE = "SELECT peer, host_id, {nt_col_name}, schema_version FROM system.peers" + _SELECT_SCHEMA_PEERS_TEMPLATE = ( + "SELECT peer, host_id, {nt_col_name}, schema_version FROM system.peers" + ) _SELECT_SCHEMA_LOCAL = "SELECT schema_version FROM system.local WHERE key='local'" _SELECT_PEERS_V2 = "SELECT * FROM system.peers_v2" @@ -3453,6 +4100,7 @@ class ControlConnection(object): class PeersQueryType(object): """internal Enum for _peers_query""" + PEERS = 0 PEERS_SCHEMA = 1 @@ -3475,13 +4123,17 @@ class PeersQueryType(object): # for testing purposes _time = time - def __init__(self, cluster, timeout, - schema_event_refresh_window, - topology_event_refresh_window, - status_event_refresh_window, - schema_meta_enabled=True, - token_meta_enabled=True, - schema_meta_page_size=1000): + def __init__( + self, + cluster, + timeout, + schema_event_refresh_window, + topology_event_refresh_window, + status_event_refresh_window, + schema_meta_enabled=True, + token_meta_enabled=True, + schema_meta_page_size=1000, + ): # use a weak reference to allow the Cluster instance to be GC'ed (and # shutdown) since implementing __del__ disables the cycle detector self._cluster = weakref.proxy(cluster) @@ -3510,7 +4162,9 @@ def connect(self): self._protocol_version = self._cluster.protocol_version self._set_new_connection(self._reconnect_internal()) - self._cluster.metadata.dbaas = self._connection._product_type == dscloud.DATASTAX_CLOUD_PRODUCT_TYPE + self._cluster.metadata.dbaas = ( + self._connection._product_type == dscloud.DATASTAX_CLOUD_PRODUCT_TYPE + ) def _set_new_connection(self, conn): """ @@ -3521,23 +4175,39 @@ def _set_new_connection(self, conn): self._connection = conn if old: - log.debug("[control connection] Closing old connection %r, replacing with %r", old, conn) + log.debug( + "[control connection] Closing old connection %r, replacing with %r", + old, + conn, + ) old.close() def _try_connect_to_hosts(self): errors = {} - lbp = self._cluster.load_balancing_policy \ - if self._cluster._config_mode == _ConfigMode.LEGACY else self._cluster._default_load_balancing_policy + lbp = ( + self._cluster.load_balancing_policy + if self._cluster._config_mode == _ConfigMode.LEGACY + else self._cluster._default_load_balancing_policy + ) - for endpoint in chain((host.endpoint for host in lbp.make_query_plan()), self._cluster.endpoints_resolved): + for endpoint in chain( + (host.endpoint for host in lbp.make_query_plan()), + self._cluster.endpoints_resolved, + ): try: return (self._try_connect(endpoint), None) except Exception as exc: errors[str(endpoint)] = exc - log.warning("[control connection] Error connecting to %s:", endpoint, exc_info=True) + log.warning( + "[control connection] Error connecting to %s:", + endpoint, + exc_info=True, + ) if self._is_shutdown: - raise DriverException("[control connection] Reconnection in progress during shutdown") + raise DriverException( + "[control connection] Reconnection in progress during shutdown" + ) return (None, errors) @@ -3574,7 +4244,9 @@ def _try_connect(self, endpoint): while True: try: - connection = self._cluster.connection_factory(endpoint, is_control_connection=True) + connection = self._cluster.connection_factory( + endpoint, is_control_connection=True + ) if self._is_shutdown: connection.close() raise DriverException("Reconnecting during shutdown") @@ -3585,14 +4257,21 @@ def _try_connect(self, endpoint): # protocol v5 is out of beta in C* >=4.0-beta5 and is now the default driver # protocol version. If the protocol version was not explicitly specified, # and that the server raises a beta protocol error, we should downgrade. - if not self._cluster._protocol_version_explicit and e.is_beta_protocol_error: - self._cluster.protocol_downgrade(endpoint, self._cluster.protocol_version) + if ( + not self._cluster._protocol_version_explicit + and e.is_beta_protocol_error + ): + self._cluster.protocol_downgrade( + endpoint, self._cluster.protocol_version + ) else: raise - log.debug("[control connection] Established new connection %r, " - "registering watchers and refreshing schema and topology", - connection) + log.debug( + "[control connection] Established new connection %r, " + "registering watchers and refreshing schema and topology", + connection, + ) # Indirect way to determine if conencted to a ScyllaDB cluster, which does not support peers_v2 # If sharding information is available, it's a ScyllaDB cluster, so do not use peers_v2 table. @@ -3601,8 +4280,12 @@ def _try_connect(self, endpoint): # Only ScyllaDB supports "USING TIMEOUT" # Sharding information signals it is ScyllaDB - self._metadata_request_timeout = None if connection.features.sharding_info is None or not self._cluster.metadata_request_timeout \ + self._metadata_request_timeout = ( + None + if connection.features.sharding_info is None + or not self._cluster.metadata_request_timeout else datetime.timedelta(seconds=self._cluster.metadata_request_timeout) + ) self._tablets_routing_v1 = connection.features.tablets_routing_v1 @@ -3610,22 +4293,48 @@ def _try_connect(self, endpoint): # _clear_watcher will be called when this ControlConnection is about to be finalized # _watch_callback will get the actual callback from the Connection and relay it to # this object (after a dereferencing a weakref) - self_weakref = weakref.ref(self, partial(_clear_watcher, weakref.proxy(connection))) + self_weakref = weakref.ref( + self, partial(_clear_watcher, weakref.proxy(connection)) + ) try: - connection.register_watchers({ - "TOPOLOGY_CHANGE": partial(_watch_callback, self_weakref, '_handle_topology_change'), - "STATUS_CHANGE": partial(_watch_callback, self_weakref, '_handle_status_change'), - "SCHEMA_CHANGE": partial(_watch_callback, self_weakref, '_handle_schema_change') - }, register_timeout=self._timeout) + connection.register_watchers( + { + "TOPOLOGY_CHANGE": partial( + _watch_callback, self_weakref, "_handle_topology_change" + ), + "STATUS_CHANGE": partial( + _watch_callback, self_weakref, "_handle_status_change" + ), + "SCHEMA_CHANGE": partial( + _watch_callback, self_weakref, "_handle_schema_change" + ), + }, + register_timeout=self._timeout, + ) sel_peers = self._get_peers_query(self.PeersQueryType.PEERS, connection) - sel_local = self._SELECT_LOCAL if self._token_meta_enabled else self._SELECT_LOCAL_NO_TOKENS - peers_query = QueryMessage(query=maybe_add_timeout_to_query(sel_peers, self._metadata_request_timeout), - consistency_level=ConsistencyLevel.ONE) - local_query = QueryMessage(query=maybe_add_timeout_to_query(sel_local, self._metadata_request_timeout), - consistency_level=ConsistencyLevel.ONE) - (peers_success, peers_result), (local_success, local_result) = connection.wait_for_responses( - peers_query, local_query, timeout=self._timeout, fail_on_error=False) + sel_local = ( + self._SELECT_LOCAL + if self._token_meta_enabled + else self._SELECT_LOCAL_NO_TOKENS + ) + peers_query = QueryMessage( + query=maybe_add_timeout_to_query( + sel_peers, self._metadata_request_timeout + ), + consistency_level=ConsistencyLevel.ONE, + ) + local_query = QueryMessage( + query=maybe_add_timeout_to_query( + sel_local, self._metadata_request_timeout + ), + consistency_level=ConsistencyLevel.ONE, + ) + (peers_success, peers_result), (local_success, local_result) = ( + connection.wait_for_responses( + peers_query, local_query, timeout=self._timeout, fail_on_error=False + ) + ) if not local_success: raise local_result @@ -3634,14 +4343,23 @@ def _try_connect(self, endpoint): # error with the peers v2 query, fallback to peers v1 self._uses_peers_v2 = False sel_peers = self._get_peers_query(self.PeersQueryType.PEERS, connection) - peers_query = QueryMessage(query=maybe_add_timeout_to_query(sel_peers, self._metadata_request_timeout), - consistency_level=ConsistencyLevel.ONE) + peers_query = QueryMessage( + query=maybe_add_timeout_to_query( + sel_peers, self._metadata_request_timeout + ), + consistency_level=ConsistencyLevel.ONE, + ) peers_result = connection.wait_for_response( - peers_query, timeout=self._timeout) + peers_query, timeout=self._timeout + ) shared_results = (peers_result, local_result) - self._refresh_node_list_and_token_map(connection, preloaded_results=shared_results) - self._refresh_schema(connection, preloaded_results=shared_results, schema_agreement_wait=-1) + self._refresh_node_list_and_token_map( + connection, preloaded_results=shared_results + ) + self._refresh_schema( + connection, preloaded_results=shared_results, schema_agreement_wait=-1 + ) except Exception: connection.close() raise @@ -3663,7 +4381,6 @@ def _reconnect(self): schedule = self._cluster.reconnection_policy.new_schedule() with self._reconnection_lock: - # cancel existing reconnection attempts if self._reconnection_handler: self._reconnection_handler.cancel() @@ -3672,9 +4389,12 @@ def _reconnect(self): # will be called with the new connection and then our # _reconnection_handler will be cleared out self._reconnection_handler = _ControlReconnectionHandler( - self, self._cluster.scheduler, schedule, + self, + self._cluster.scheduler, + schedule, self._get_and_set_reconnection_handler, - new_handler=None) + new_handler=None, + ) self._reconnection_handler.start() except Exception: log.debug("[control connection] error reconnecting", exc_info=True) @@ -3727,16 +4447,27 @@ def refresh_schema(self, force=False, **kwargs): self._signal_error() return False - def _refresh_schema(self, connection, preloaded_results=None, schema_agreement_wait=None, force=False, **kwargs): + def _refresh_schema( + self, + connection, + preloaded_results=None, + schema_agreement_wait=None, + force=False, + **kwargs, + ): if self._cluster.is_shutdown: return False - agreed = self.wait_for_schema_agreement(connection, - preloaded_results=preloaded_results, - wait_time=schema_agreement_wait) + agreed = self.wait_for_schema_agreement( + connection, + preloaded_results=preloaded_results, + wait_time=schema_agreement_wait, + ) if not self._schema_meta_enabled and not force: - log.debug("[control connection] Skipping schema refresh because schema metadata is disabled") + log.debug( + "[control connection] Skipping schema refresh because schema metadata is disabled" + ) return False if not agreed: @@ -3748,26 +4479,35 @@ def _refresh_schema(self, connection, preloaded_results=None, schema_agreement_w self._timeout, fetch_size=self._schema_meta_page_size, metadata_request_timeout=self._metadata_request_timeout, - **kwargs) + **kwargs, + ) return True def refresh_node_list_and_token_map(self, force_token_rebuild=False): try: if self._connection: - self._refresh_node_list_and_token_map(self._connection, force_token_rebuild=force_token_rebuild) + self._refresh_node_list_and_token_map( + self._connection, force_token_rebuild=force_token_rebuild + ) return True except ReferenceError: pass # our weak reference to the Cluster is no good except Exception: - log.debug("[control connection] Error refreshing node list and token map", exc_info=True) + log.debug( + "[control connection] Error refreshing node list and token map", + exc_info=True, + ) self._signal_error() return False - def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, - force_token_rebuild=False): + def _refresh_node_list_and_token_map( + self, connection, preloaded_results=None, force_token_rebuild=False + ): if preloaded_results: - log.debug("[control connection] Refreshing node list and token map using preloaded results") + log.debug( + "[control connection] Refreshing node list and token map using preloaded results" + ) peers_result = preloaded_results[0] local_result = preloaded_results[1] else: @@ -3779,12 +4519,21 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, else: log.debug("[control connection] Refreshing node list and token map") sel_local = self._SELECT_LOCAL - peers_query = QueryMessage(query=maybe_add_timeout_to_query(sel_peers, self._metadata_request_timeout), - consistency_level=cl) - local_query = QueryMessage(query=maybe_add_timeout_to_query(sel_local, self._metadata_request_timeout), - consistency_level=cl) + peers_query = QueryMessage( + query=maybe_add_timeout_to_query( + sel_peers, self._metadata_request_timeout + ), + consistency_level=cl, + ) + local_query = QueryMessage( + query=maybe_add_timeout_to_query( + sel_local, self._metadata_request_timeout + ), + consistency_level=cl, + ) peers_result, local_result = connection.wait_for_responses( - peers_query, local_query, timeout=self._timeout) + peers_query, local_query, timeout=self._timeout + ) peers_result = dict_factory(peers_result.column_names, peers_result.parsed_rows) @@ -3795,7 +4544,9 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, found_endpoints = set() if local_result.parsed_rows: - local_rows = dict_factory(local_result.column_names, local_result.parsed_rows) + local_rows = dict_factory( + local_result.column_names, local_result.parsed_rows + ) local_row = local_rows[0] cluster_name = local_row["cluster_name"] self._cluster.metadata.cluster_name = cluster_name @@ -3808,7 +4559,9 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, # Check metadata.partitioner to see if we haven't built anything yet. If # every node in the cluster was in the contact points, we won't discover # any new nodes, so we need this additional check. (See PYTHON-90) - should_rebuild_token_map = force_token_rebuild or self._cluster.metadata.partitioner is None + should_rebuild_token_map = ( + force_token_rebuild or self._cluster.metadata.partitioner is None + ) for row in peers_result: if not self._is_valid_peer(row): continue @@ -3817,11 +4570,20 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, host_id = row.get("host_id") if endpoint in found_endpoints: - log.warning("Found multiple hosts with the same endpoint(%s). Excluding peer %s - %s", endpoint, row.get("peer"), host_id) + log.warning( + "Found multiple hosts with the same endpoint(%s). Excluding peer %s - %s", + endpoint, + row.get("peer"), + host_id, + ) continue if host_id in found_host_ids: - log.warning("Found multiple hosts with the same host_id (%s). Excluding peer %s", host_id, row.get("peer")) + log.warning( + "Found multiple hosts with the same host_id (%s). Excluding peer %s", + host_id, + row.get("peer"), + ) continue found_host_ids.add(host_id) @@ -3833,11 +4595,18 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, if host is None: host = self._cluster.metadata.get_host_by_host_id(host_id) if host and host.endpoint != endpoint: - log.debug("[control connection] Updating host ip from %s to %s for (%s)", host.endpoint, endpoint, host_id) + log.debug( + "[control connection] Updating host ip from %s to %s for (%s)", + host.endpoint, + endpoint, + host_id, + ) reconnector = host.get_and_set_reconnection_handler(None) if reconnector: reconnector.cancel() - self._cluster.on_down(host, is_host_addition=False, expect_host_to_be_down=True) + self._cluster.on_down( + host, is_host_addition=False, expect_host_to_be_down=True + ) old_endpoint = host.endpoint host.endpoint = endpoint @@ -3845,11 +4614,22 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, self._cluster.on_up(host) if host is None: - log.debug("[control connection] Found new host to connect to: %s", endpoint) - host, _ = self._cluster.add_host(endpoint, datacenter=datacenter, rack=rack, signal=True, refresh_nodes=False, host_id=host_id) + log.debug( + "[control connection] Found new host to connect to: %s", endpoint + ) + host, _ = self._cluster.add_host( + endpoint, + datacenter=datacenter, + rack=rack, + signal=True, + refresh_nodes=False, + host_id=host_id, + ) should_rebuild_token_map = True else: - should_rebuild_token_map |= self._update_location_info(host, datacenter, rack) + should_rebuild_token_map |= self._update_location_info( + host, datacenter, rack + ) host.host_id = host_id host.broadcast_address = _NodeInfo.get_broadcast_address(row) @@ -3869,12 +4649,19 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, for old_host_id, old_host in self._cluster.metadata.all_hosts_items(): if old_host_id not in found_host_ids: should_rebuild_token_map = True - log.debug("[control connection] Removing host not found in peers metadata: %r", old_host) - self._cluster.metadata.remove_host_by_host_id(old_host_id, old_host.endpoint) + log.debug( + "[control connection] Removing host not found in peers metadata: %r", + old_host, + ) + self._cluster.metadata.remove_host_by_host_id( + old_host_id, old_host.endpoint + ) log.debug("[control connection] Finished fetching ring info") if partitioner and should_rebuild_token_map: - log.debug("[control connection] Rebuilding token map due to topology changes") + log.debug( + "[control connection] Rebuilding token map due to topology changes" + ) self._cluster.metadata.rebuild_token_map(partitioner, token_map) @staticmethod @@ -3884,32 +4671,37 @@ def _is_valid_peer(row): if not broadcast_rpc: log.warning( - "Found an invalid row for peer - missing broadcast_rpc (full row: %s). Ignoring host." % - row) + "Found an invalid row for peer - missing broadcast_rpc (full row: %s). Ignoring host." + % row + ) return False if not host_id: log.warning( - "Found an invalid row for peer - missing host_id (broadcast_rpc: %s). Ignoring host." % - broadcast_rpc) + "Found an invalid row for peer - missing host_id (broadcast_rpc: %s). Ignoring host." + % broadcast_rpc + ) return False if not row.get("data_center"): log.warning( - "Found an invalid row for peer - missing data_center (broadcast_rpc: %s, host_id: %s). Ignoring host." % - (broadcast_rpc, host_id)) + "Found an invalid row for peer - missing data_center (broadcast_rpc: %s, host_id: %s). Ignoring host." + % (broadcast_rpc, host_id) + ) return False if not row.get("rack"): log.warning( - "Found an invalid row for peer - missing rack (broadcast_rpc: %s, host_id: %s). Ignoring host." % - (broadcast_rpc, host_id)) + "Found an invalid row for peer - missing rack (broadcast_rpc: %s, host_id: %s). Ignoring host." + % (broadcast_rpc, host_id) + ) return False if "tokens" in row and not row.get("tokens"): log.debug( - "Found a zero-token node - tokens is None (broadcast_rpc: %s, host_id: %s). Ignoring host." % - (broadcast_rpc, host_id)) + "Found a zero-token node - tokens is None (broadcast_rpc: %s, host_id: %s). Ignoring host." + % (broadcast_rpc, host_id) + ) return False return True @@ -3954,8 +4746,12 @@ def _handle_topology_change(self, event): host = self._cluster.metadata.get_host(addr, port) if change_type == "NEW_NODE" or change_type == "MOVED_NODE": if self._topology_event_refresh_window >= 0: - delay = self._delay_for_event_type('topology_change', self._topology_event_refresh_window) - self._cluster.scheduler.schedule_unique(delay, self._refresh_nodes_if_not_up, host) + delay = self._delay_for_event_type( + "topology_change", self._topology_event_refresh_window + ) + self._cluster.scheduler.schedule_unique( + delay, self._refresh_nodes_if_not_up, host + ) elif change_type == "REMOVED_NODE": self._cluster.scheduler.schedule_unique(0, self._cluster.remove_host, host) @@ -3964,12 +4760,18 @@ def _handle_status_change(self, event): addr, port = event["address"] host = self._cluster.metadata.get_host(addr, port) if change_type == "UP": - delay = self._delay_for_event_type('status_change', self._status_event_refresh_window) + delay = self._delay_for_event_type( + "status_change", self._status_event_refresh_window + ) if host is None: # this is the first time we've seen the node - self._cluster.scheduler.schedule_unique(delay, self.refresh_node_list_and_token_map) + self._cluster.scheduler.schedule_unique( + delay, self.refresh_node_list_and_token_map + ) else: - self._cluster.scheduler.schedule_unique(delay, self._cluster.on_up, host) + self._cluster.scheduler.schedule_unique( + delay, self._cluster.on_up, host + ) elif change_type == "DOWN": # Note that there is a slight risk we can receive the event late and thus # mark the host down even though we already had reconnected successfully. @@ -3982,12 +4784,20 @@ def _handle_status_change(self, event): def _handle_schema_change(self, event): if self._schema_event_refresh_window < 0: return - delay = self._delay_for_event_type('schema_change', self._schema_event_refresh_window) + delay = self._delay_for_event_type( + "schema_change", self._schema_event_refresh_window + ) self._cluster.scheduler.schedule_unique(delay, self.refresh_schema, **event) - def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wait_time=None): + def wait_for_schema_agreement( + self, connection=None, preloaded_results=None, wait_time=None + ): - total_timeout = wait_time if wait_time is not None else self._cluster.max_schema_agreement_wait + total_timeout = ( + wait_time + if wait_time is not None + else self._cluster.max_schema_agreement_wait + ) if total_timeout <= 0: return True @@ -4003,11 +4813,15 @@ def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wai connection = self._connection if preloaded_results: - log.debug("[control connection] Attempting to use preloaded results for schema agreement") + log.debug( + "[control connection] Attempting to use preloaded results for schema agreement" + ) peers_result = preloaded_results[0] local_result = preloaded_results[1] - schema_mismatches = self._get_schema_mismatches(peers_result, local_result, connection.endpoint) + schema_mismatches = self._get_schema_mismatches( + peers_result, local_result, connection.endpoint + ) if schema_mismatches is None: return True @@ -4016,30 +4830,48 @@ def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wai elapsed = 0 cl = ConsistencyLevel.ONE schema_mismatches = None - select_peers_query = self._get_peers_query(self.PeersQueryType.PEERS_SCHEMA, connection) + select_peers_query = self._get_peers_query( + self.PeersQueryType.PEERS_SCHEMA, connection + ) while elapsed < total_timeout: - peers_query = QueryMessage(query=maybe_add_timeout_to_query(select_peers_query, self._metadata_request_timeout), - consistency_level=cl) - local_query = QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_SCHEMA_LOCAL, self._metadata_request_timeout), - consistency_level=cl) + peers_query = QueryMessage( + query=maybe_add_timeout_to_query( + select_peers_query, self._metadata_request_timeout + ), + consistency_level=cl, + ) + local_query = QueryMessage( + query=maybe_add_timeout_to_query( + self._SELECT_SCHEMA_LOCAL, self._metadata_request_timeout + ), + consistency_level=cl, + ) try: timeout = min(self._timeout, total_timeout - elapsed) peers_result, local_result = connection.wait_for_responses( - peers_query, local_query, timeout=timeout) + peers_query, local_query, timeout=timeout + ) except OperationTimedOut as timeout: - log.debug("[control connection] Timed out waiting for " - "response during schema agreement check: %s", timeout) + log.debug( + "[control connection] Timed out waiting for " + "response during schema agreement check: %s", + timeout, + ) elapsed = self._time.time() - start continue except ConnectionShutdown: if self._is_shutdown: - log.debug("[control connection] Aborting wait for schema match due to shutdown") + log.debug( + "[control connection] Aborting wait for schema match due to shutdown" + ) return None else: raise - schema_mismatches = self._get_schema_mismatches(peers_result, local_result, connection.endpoint) + schema_mismatches = self._get_schema_mismatches( + peers_result, local_result, connection.endpoint + ) if schema_mismatches is None: return True @@ -4047,8 +4879,11 @@ def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wai self._time.sleep(0.2) elapsed = self._time.time() - start - log.warning("Node %s is reporting a schema disagreement: %s", - connection.endpoint, schema_mismatches) + log.warning( + "Node %s is reporting a schema disagreement: %s", + connection.endpoint, + schema_mismatches, + ) return False def _get_schema_mismatches(self, peers_result, local_result, local_address): @@ -4056,12 +4891,14 @@ def _get_schema_mismatches(self, peers_result, local_result, local_address): versions = defaultdict(set) if local_result.parsed_rows: - local_row = dict_factory(local_result.column_names, local_result.parsed_rows)[0] + local_row = dict_factory( + local_result.column_names, local_result.parsed_rows + )[0] if local_row.get("schema_version"): versions[local_row.get("schema_version")].add(local_address) for row in peers_result: - schema_ver = row.get('schema_version') + schema_ver = row.get("schema_version") if not schema_ver: continue endpoint = self._cluster.endpoint_factory.create(row) @@ -4090,29 +4927,56 @@ def _get_peers_query(self, peers_query_type, connection=None): - use that to choose the column name for the transport address (see APOLLO-1130), and - use that column name in the provided peers query template. """ - if peers_query_type not in (self.PeersQueryType.PEERS, self.PeersQueryType.PEERS_SCHEMA): + if peers_query_type not in ( + self.PeersQueryType.PEERS, + self.PeersQueryType.PEERS_SCHEMA, + ): raise ValueError("Invalid peers query type: %s" % peers_query_type) if self._uses_peers_v2: if peers_query_type == self.PeersQueryType.PEERS: - query = self._SELECT_PEERS_V2 if self._token_meta_enabled else self._SELECT_PEERS_NO_TOKENS_V2 + query = ( + self._SELECT_PEERS_V2 + if self._token_meta_enabled + else self._SELECT_PEERS_NO_TOKENS_V2 + ) else: query = self._SELECT_SCHEMA_PEERS_V2 else: - if peers_query_type == self.PeersQueryType.PEERS and self._token_meta_enabled: + if ( + peers_query_type == self.PeersQueryType.PEERS + and self._token_meta_enabled + ): query = self._SELECT_PEERS else: - query_template = (self._SELECT_SCHEMA_PEERS_TEMPLATE - if peers_query_type == self.PeersQueryType.PEERS_SCHEMA - else self._SELECT_PEERS_NO_TOKENS_TEMPLATE) - original_endpoint_host = self._cluster.metadata.get_host(connection.original_endpoint) - host_release_version = None if original_endpoint_host is None else original_endpoint_host.release_version - host_dse_version = None if original_endpoint_host is None else original_endpoint_host.dse_version + query_template = ( + self._SELECT_SCHEMA_PEERS_TEMPLATE + if peers_query_type == self.PeersQueryType.PEERS_SCHEMA + else self._SELECT_PEERS_NO_TOKENS_TEMPLATE + ) + original_endpoint_host = self._cluster.metadata.get_host( + connection.original_endpoint + ) + host_release_version = ( + None + if original_endpoint_host is None + else original_endpoint_host.release_version + ) + host_dse_version = ( + None + if original_endpoint_host is None + else original_endpoint_host.dse_version + ) uses_native_address_query = ( - host_dse_version and Version(host_dse_version) >= self._MINIMUM_NATIVE_ADDRESS_DSE_VERSION) + host_dse_version + and Version(host_dse_version) + >= self._MINIMUM_NATIVE_ADDRESS_DSE_VERSION + ) if uses_native_address_query: - query = query_template.format(nt_col_name="native_transport_address") + query = query_template.format( + nt_col_name="native_transport_address" + ) elif host_release_version: query = query_template.format(nt_col_name="rpc_address") else: @@ -4133,7 +4997,8 @@ def _signal_error(self): # that errors have already been reported, so we're fine if host: self._cluster.signal_connection_failure( - host, self._connection.last_error, is_host_addition=False) + host, self._connection.last_error, is_host_addition=False + ) return # if the connection is not defunct or the host already left, reconnect @@ -4146,10 +5011,16 @@ def on_up(self, host): def on_down(self, host): conn = self._connection - if conn and conn.endpoint == host.endpoint and \ - self._reconnection_handler is None: - log.debug("[control connection] Control connection host (%s) is " - "considered down, starting reconnection", host) + if ( + conn + and conn.endpoint == host.endpoint + and self._reconnection_handler is None + ): + log.debug( + "[control connection] Control connection host (%s) is " + "considered down, starting reconnection", + host, + ) # this will result in a task being submitted to the executor to reconnect self.reconnect() @@ -4160,18 +5031,23 @@ def on_add(self, host, refresh_nodes=True): def on_remove(self, host): c = self._connection if c and c.endpoint == host.endpoint: - log.debug("[control connection] Control connection host (%s) is being removed. Reconnecting", host) + log.debug( + "[control connection] Control connection host (%s) is being removed. Reconnecting", + host, + ) # refresh will be done on reconnect self.reconnect() else: self.refresh_node_list_and_token_map(force_token_rebuild=True) def get_connections(self): - c = getattr(self, '_connection', None) + c = getattr(self, "_connection", None) return [c] if c else [] def return_connection(self, connection): - if connection is self._connection and (connection.is_defunct or connection.is_closed): + if connection is self._connection and ( + connection.is_defunct or connection.is_closed + ): self.reconnect() @@ -4186,7 +5062,6 @@ def _stop_scheduler(scheduler, thread): class _Scheduler(Thread): - _queue = None _scheduled_tasks = None _executor = None @@ -4240,7 +5115,9 @@ def run(self): run_at, i, task = self._queue.get(block=True, timeout=None) if self.is_shutdown: if task: - log.debug("Not executing scheduled task due to Scheduler shutdown") + log.debug( + "Not executing scheduled task due to Scheduler shutdown" + ) return if run_at <= time.time(): self._scheduled_tasks.discard(task) @@ -4261,14 +5138,16 @@ def _log_if_failed(self, future): if exc: log.warning( "An internally scheduled tasked failed with an unhandled exception:", - exc_info=exc) + exc_info=exc, + ) def refresh_schema_and_set_result(control_conn, response_future, connection, **kwargs): try: - log.debug("Refreshing schema in response to schema change. " - "%s", kwargs) - response_future.is_schema_agreed = control_conn._refresh_schema(connection, **kwargs) + log.debug("Refreshing schema in response to schema change. %s", kwargs) + response_future.is_schema_agreed = control_conn._refresh_schema( + connection, **kwargs + ) except Exception: log.exception("Exception refreshing schema in response to schema change:") response_future.session.submit(control_conn.refresh_schema, **kwargs) @@ -4348,13 +5227,28 @@ class ResponseFuture(object): _warned_timeout = False - def __init__(self, session, message, query, timeout, metrics=None, prepared_statement=None, - retry_policy=RetryPolicy(), row_factory=None, load_balancer=None, start_time=None, - speculative_execution_plan=None, continuous_paging_state=None, host=None): + def __init__( + self, + session, + message, + query, + timeout, + metrics=None, + prepared_statement=None, + retry_policy=RetryPolicy(), + row_factory=None, + load_balancer=None, + start_time=None, + speculative_execution_plan=None, + continuous_paging_state=None, + host=None, + ): self.session = session # TODO: normalize handling of retry policy and row factory self.row_factory = row_factory or session.row_factory - self._load_balancer = load_balancer or session.cluster._default_load_balancing_policy + self._load_balancer = ( + load_balancer or session.cluster._default_load_balancing_policy + ) self.message = message self.query = query self.timeout = timeout @@ -4364,7 +5258,9 @@ def __init__(self, session, message, query, timeout, metrics=None, prepared_stat self._callback_lock = Lock() self._start_time = start_time or time.time() self._host = host - self._spec_execution_plan = speculative_execution_plan or self._spec_execution_plan + self._spec_execution_plan = ( + speculative_execution_plan or self._spec_execution_plan + ) self._make_query_plan() self._event = Event() self._errors = {} @@ -4385,10 +5281,14 @@ def _start_timer(self): spec_delay = self._spec_execution_plan.next_execution(self._current_host) if spec_delay >= 0: if self._time_remaining is None or self._time_remaining > spec_delay: - self._timer = self.session.cluster.connection_class.create_timer(spec_delay, self._on_speculative_execute) + self._timer = self.session.cluster.connection_class.create_timer( + spec_delay, self._on_speculative_execute + ) return if self._time_remaining is not None: - self._timer = self.session.cluster.connection_class.create_timer(self._time_remaining, self._on_timeout) + self._timer = self.session.cluster.connection_class.create_timer( + self._time_remaining, self._on_timeout + ) def _cancel_timer(self): if self._timer: @@ -4405,8 +5305,7 @@ def _on_timeout(self, _attempts=0): # PYTHON-853: for short timeouts, we sometimes race with our __init__ if self._connection is None and _attempts < 3: self._timer = self.session.cluster.connection_class.create_timer( - 0.01, - partial(self._on_timeout, _attempts=_attempts + 1) + 0.01, partial(self._on_timeout, _attempts=_attempts + 1) ) return @@ -4419,7 +5318,9 @@ def _on_timeout(self, _attempts=0): # wait for it endlessly except KeyError: key = "Connection defunct by heartbeat" - errors = {key: "Client request timeout. See Session.execute[_async](timeout)"} + errors = { + key: "Client request timeout. See Session.execute[_async](timeout)" + } self._set_final_exception(OperationTimedOut(errors, self._current_host)) return @@ -4432,7 +5333,10 @@ def _on_timeout(self, _attempts=0): # query could get a response from the old query with self._connection.lock: self._connection.orphaned_request_ids.add(self._req_id) - if len(self._connection.orphaned_request_ids) >= self._connection.orphaned_threshold: + if ( + len(self._connection.orphaned_request_ids) + >= self._connection.orphaned_threshold + ): self._connection.orphaned_threshold_reached = True pool.return_connection(self._connection, stream_was_orphaned=True) @@ -4440,19 +5344,26 @@ def _on_timeout(self, _attempts=0): errors = self._errors if not errors: if self.is_schema_agreed: - key = str(self._current_host.endpoint) if self._current_host else 'no host queried before timeout' - errors = {key: "Client request timeout. See Session.execute[_async](timeout)"} + key = ( + str(self._current_host.endpoint) + if self._current_host + else "no host queried before timeout" + ) + errors = { + key: "Client request timeout. See Session.execute[_async](timeout)" + } else: connection = self.session.cluster.control_connection._connection - host = str(connection.endpoint) if connection else 'unknown' - errors = {host: "Request timed out while waiting for schema agreement. See Session.execute[_async](timeout) and Cluster.max_schema_agreement_wait."} + host = str(connection.endpoint) if connection else "unknown" + errors = { + host: "Request timed out while waiting for schema agreement. See Session.execute[_async](timeout) and Cluster.max_schema_agreement_wait." + } self._set_final_exception(OperationTimedOut(errors, self._current_host)) def _on_speculative_execute(self): self._timer = None if not self._event.is_set(): - # PYTHON-836, the speculative queries must be after # the query is sent from the main thread, otherwise the # query from the main thread may raise NoHostAvailable @@ -4461,7 +5372,9 @@ def _on_speculative_execute(self): # We reschedule this call until the main thread has succeeded # making a query if not self.attempted_hosts: - self._timer = self.session.cluster.connection_class.create_timer(0.01, self._on_speculative_execute) + self._timer = self.session.cluster.connection_class.create_timer( + 0.01, self._on_speculative_execute + ) return if self._time_remaining is not None: @@ -4481,10 +5394,12 @@ def _make_query_plan(self): # convert the list/generator/etc to an iterator so that subsequent # calls to send_request (which retries may do) will resume where # they last left off - self.query_plan = iter(self._load_balancer.make_query_plan(self.session.keyspace, self.query)) + self.query_plan = iter( + self._load_balancer.make_query_plan(self.session.keyspace, self.query) + ) def send_request(self, error_no_hosts=True): - """ Internal """ + """Internal""" # query_plan is an iterator, so this will resume where we last left # off if send_request() is called multiple times for host in self.query_plan: @@ -4492,12 +5407,18 @@ def send_request(self, error_no_hosts=True): if req_id is not None: self._req_id = req_id return True - if self.timeout is not None and time.time() - self._start_time > self.timeout: + if ( + self.timeout is not None + and time.time() - self._start_time > self.timeout + ): self._on_timeout() return True if error_no_hosts: - self._set_final_exception(NoHostAvailable( - "Unable to complete the operation against any hosts", self._errors)) + self._set_final_exception( + NoHostAvailable( + "Unable to complete the operation against any hosts", self._errors + ) + ) return False def _query(self, host, message=None, cb=None): @@ -4506,7 +5427,9 @@ def _query(self, host, message=None, cb=None): pool = self.session._pools.get(host) if not pool: - self._errors[host] = ConnectionException("Host has been marked down or removed") + self._errors[host] = ConnectionException( + "Host has been marked down or removed" + ) return None elif pool.is_shutdown: self._errors[host] = ConnectionException("Pool is shutdown") @@ -4518,23 +5441,39 @@ def _query(self, host, message=None, cb=None): try: # TODO get connectTimeout from cluster settings if self.query: - connection, request_id = pool.borrow_connection(timeout=2.0, routing_key=self.query.routing_key, keyspace=self.query.keyspace, table=self.query.table) + connection, request_id = pool.borrow_connection( + timeout=2.0, + routing_key=self.query.routing_key, + keyspace=self.query.keyspace, + table=self.query.table, + ) else: connection, request_id = pool.borrow_connection(timeout=2.0) self._connection = connection - result_meta = self.prepared_statement.result_metadata if self.prepared_statement else [] + result_meta = ( + self.prepared_statement.result_metadata + if self.prepared_statement + else [] + ) if cb is None: cb = partial(self._set_result, host, connection, pool) - self.request_encoded_size = connection.send_msg(message, request_id, cb=cb, - encoder=self._protocol_handler.encode_message, - decoder=self._protocol_handler.decode_message, - result_metadata=result_meta) + self.request_encoded_size = connection.send_msg( + message, + request_id, + cb=cb, + encoder=self._protocol_handler.encode_message, + decoder=self._protocol_handler.decode_message, + result_metadata=result_meta, + ) self.attempted_hosts.append(host) return request_id except NoConnectionsAvailable as exc: - log.debug("All connections for host %s are at capacity, moving to the next host", host) + log.debug( + "All connections for host %s are at capacity, moving to the next host", + host, + ) self._errors[host] = exc except ConnectionBusy as exc: log.debug("Connection for host %s is busy, moving to the next host", host) @@ -4575,7 +5514,9 @@ def warnings(self): """ # TODO: When timers are introduced, just make this wait if not self._event.is_set(): - raise DriverException("warnings cannot be retrieved before ResponseFuture is finalized") + raise DriverException( + "warnings cannot be retrieved before ResponseFuture is finalized" + ) return self._warnings @property @@ -4593,7 +5534,9 @@ def custom_payload(self): """ # TODO: When timers are introduced, just make this wait if not self._event.is_set(): - raise DriverException("custom_payload cannot be retrieved before ResponseFuture is finalized") + raise DriverException( + "custom_payload cannot be retrieved before ResponseFuture is finalized" + ) return self._custom_payload def start_fetching_next_page(self): @@ -4618,7 +5561,9 @@ def start_fetching_next_page(self): self.send_request() def _reprepare(self, prepare_message, host, connection, pool): - cb = partial(self.session.submit, self._execute_after_prepare, host, connection, pool) + cb = partial( + self.session.submit, self._execute_after_prepare, host, connection, pool + ) request_id = self._query(host, prepare_message, cb=cb) if request_id is None: # try to submit the original prepared statement on some other host @@ -4630,19 +5575,25 @@ def _set_result(self, host, connection, pool, response): if pool and not pool.is_shutdown: pool.return_connection(connection) - trace_id = getattr(response, 'trace_id', None) + trace_id = getattr(response, "trace_id", None) if trace_id: if not self._query_traces: self._query_traces = [] self._query_traces.append(QueryTrace(trace_id, self.session)) - self._warnings = getattr(response, 'warnings', None) - self._custom_payload = getattr(response, 'custom_payload', None) + self._warnings = getattr(response, "warnings", None) + self._custom_payload = getattr(response, "custom_payload", None) - if self._custom_payload and self.session.cluster.control_connection._tablets_routing_v1 and 'tablets-routing-v1' in self._custom_payload: + if ( + self._custom_payload + and self.session.cluster.control_connection._tablets_routing_v1 + and "tablets-routing-v1" in self._custom_payload + ): protocol = self.session.cluster.protocol_version - info = self._custom_payload.get('tablets-routing-v1') - ctype = types.lookup_casstype('TupleType(LongType, LongType, ListType(TupleType(UUIDType, Int32Type)))') + info = self._custom_payload.get("tablets-routing-v1") + ctype = types.lookup_casstype( + "TupleType(LongType, LongType, ListType(TupleType(UUIDType, Int32Type)))" + ) tablet_routing_info = ctype.from_binary(info, protocol) first_token = tablet_routing_info[0] last_token = tablet_routing_info[1] @@ -4650,11 +5601,13 @@ def _set_result(self, host, connection, pool, response): tablet = Tablet.from_row(first_token, last_token, tablet_replicas) keyspace = self.query.keyspace table = self.query.table - self.session.cluster.metadata._tablets.add_tablet(keyspace, table, tablet) + self.session.cluster.metadata._tablets.add_tablet( + keyspace, table, tablet + ) if isinstance(response, ResultMessage): if response.kind == RESULT_KIND_SET_KEYSPACE: - session = getattr(self, 'session', None) + session = getattr(self, "session", None) # since we're running on the event loop thread, we need to # use a non-blocking method for setting the keyspace on # all connections in this session, otherwise the event @@ -4664,7 +5617,8 @@ def _set_result(self, host, connection, pool, response): # event loop thread. if session: session._set_keyspace_for_all_pools( - response.new_keyspace, self._set_keyspace_completed) + response.new_keyspace, self._set_keyspace_completed + ) elif response.kind == RESULT_KIND_SCHEMA_CHANGE: # refresh the schema before responding, but do it in another # thread instead of the event loop thread @@ -4672,15 +5626,24 @@ def _set_result(self, host, connection, pool, response): self.session.submit( refresh_schema_and_set_result, self.session.cluster.control_connection, - self, connection, **response.schema_change_event) + self, + connection, + **response.schema_change_event, + ) elif response.kind == RESULT_KIND_ROWS: self._paging_state = response.paging_state self._col_names = response.column_names self._col_types = response.column_types - if getattr(self.message, 'continuous_paging_options', None): - self._handle_continuous_paging_first_response(connection, response) + if getattr(self.message, "continuous_paging_options", None): + self._handle_continuous_paging_first_response( + connection, response + ) else: - self._set_final_result(self.row_factory(response.column_names, response.parsed_rows)) + self._set_final_result( + self.row_factory( + response.column_names, response.parsed_rows + ) + ) elif response.kind == RESULT_KIND_VOID: self._set_final_result(None) else: @@ -4692,71 +5655,107 @@ def _set_result(self, host, connection, pool, response): if self._metrics is not None: self._metrics.on_read_timeout() retry = retry_policy.on_read_timeout( - self.query, retry_num=self._query_retries, **response.info) + self.query, retry_num=self._query_retries, **response.info + ) elif isinstance(response, WriteTimeoutErrorMessage): if self._metrics is not None: self._metrics.on_write_timeout() retry = retry_policy.on_write_timeout( - self.query, retry_num=self._query_retries, **response.info) + self.query, retry_num=self._query_retries, **response.info + ) elif isinstance(response, UnavailableErrorMessage): if self._metrics is not None: self._metrics.on_unavailable() retry = retry_policy.on_unavailable( - self.query, retry_num=self._query_retries, **response.info) - elif isinstance(response, (OverloadedErrorMessage, - IsBootstrappingErrorMessage, - TruncateError, ServerError)): + self.query, retry_num=self._query_retries, **response.info + ) + elif isinstance( + response, + ( + OverloadedErrorMessage, + IsBootstrappingErrorMessage, + TruncateError, + ServerError, + ), + ): log.warning("Host %s error: %s.", host, response.summary) if self._metrics is not None: self._metrics.on_other_error() - cl = getattr(self.message, 'consistency_level', None) + cl = getattr(self.message, "consistency_level", None) retry = retry_policy.on_request_error( - self.query, cl, error=response, - retry_num=self._query_retries) + self.query, cl, error=response, retry_num=self._query_retries + ) elif isinstance(response, PreparedQueryNotFound): if self.prepared_statement: query_id = self.prepared_statement.query_id - assert query_id == response.info, \ - "Got different query ID in server response (%s) than we " \ + assert query_id == response.info, ( + "Got different query ID in server response (%s) than we " "had before (%s)" % (response.info, query_id) + ) else: query_id = response.info try: - prepared_statement = self.session.cluster._prepared_statements[query_id] + prepared_statement = self.session.cluster._prepared_statements[ + query_id + ] except KeyError: if not self.prepared_statement: - log.error("Tried to execute unknown prepared statement: id=%s", - query_id.encode('hex')) + log.error( + "Tried to execute unknown prepared statement: id=%s", + query_id.encode("hex"), + ) self._set_final_exception(response) return else: prepared_statement = self.prepared_statement - self.session.cluster._prepared_statements[query_id] = prepared_statement + self.session.cluster._prepared_statements[query_id] = ( + prepared_statement + ) current_keyspace = self._connection.keyspace prepared_keyspace = prepared_statement.keyspace - if not ProtocolVersion.uses_keyspace_flag(self.session.cluster.protocol_version) \ - and prepared_keyspace and current_keyspace != prepared_keyspace: + if ( + not ProtocolVersion.uses_keyspace_flag( + self.session.cluster.protocol_version + ) + and prepared_keyspace + and current_keyspace != prepared_keyspace + ): self._set_final_exception( - ValueError("The Session's current keyspace (%s) does " - "not match the keyspace the statement was " - "prepared with (%s)" % - (current_keyspace, prepared_keyspace))) + ValueError( + "The Session's current keyspace (%s) does " + "not match the keyspace the statement was " + "prepared with (%s)" + % (current_keyspace, prepared_keyspace) + ) + ) return - log.debug("Re-preparing unrecognized prepared statement against host %s: %s", - host, prepared_statement.query_string) - prepared_keyspace = prepared_statement.keyspace \ - if ProtocolVersion.uses_keyspace_flag(self.session.cluster.protocol_version) else None - prepare_message = PrepareMessage(query=prepared_statement.query_string, - keyspace=prepared_keyspace) + log.debug( + "Re-preparing unrecognized prepared statement against host %s: %s", + host, + prepared_statement.query_string, + ) + prepared_keyspace = ( + prepared_statement.keyspace + if ProtocolVersion.uses_keyspace_flag( + self.session.cluster.protocol_version + ) + else None + ) + prepare_message = PrepareMessage( + query=prepared_statement.query_string, + keyspace=prepared_keyspace, + ) # since this might block, run on the executor to avoid hanging # the event loop thread - self.session.submit(self._reprepare, prepare_message, host, connection, pool) + self.session.submit( + self._reprepare, prepare_message, host, connection, pool + ) return else: - if hasattr(response, 'to_exception'): + if hasattr(response, "to_exception"): self._set_final_exception(response.to_exception()) else: self._set_final_exception(response) @@ -4768,12 +5767,13 @@ def _set_result(self, host, connection, pool, response): self._metrics.on_connection_error() if not isinstance(response, ConnectionShutdown): self._connection.defunct(response) - cl = getattr(self.message, 'consistency_level', None) + cl = getattr(self.message, "consistency_level", None) retry = self._retry_policy.on_request_error( - self.query, cl, error=response, retry_num=self._query_retries) + self.query, cl, error=response, retry_num=self._query_retries + ) self._handle_retry_decision(retry, response, host) elif isinstance(response, Exception): - if hasattr(response, 'to_exception'): + if hasattr(response, "to_exception"): self._set_final_exception(response.to_exception()) else: self._set_final_exception(response) @@ -4786,14 +5786,18 @@ def _set_result(self, host, connection, pool, response): self._set_final_exception(exc) except Exception as exc: # almost certainly caused by a bug, but we need to set something here - log.exception("Unexpected exception while handling result in ResponseFuture:") + log.exception( + "Unexpected exception while handling result in ResponseFuture:" + ) self._set_final_exception(exc) def _handle_continuous_paging_first_response(self, connection, response): - self._continuous_paging_session = connection.new_continuous_paging_session(response.stream_id, - self._protocol_handler.decode_message, - self.row_factory, - self._continuous_paging_state) + self._continuous_paging_session = connection.new_continuous_paging_session( + response.stream_id, + self._protocol_handler.decode_message, + self.row_factory, + self._continuous_paging_state, + ) self._continuous_paging_session.on_message(response) self._set_final_result(self._continuous_paging_session.results()) @@ -4801,8 +5805,11 @@ def _set_keyspace_completed(self, errors): if not errors: self._set_final_result(None) else: - self._set_final_exception(ConnectionException( - "Failed to set keyspace on all hosts: %s" % (errors,))) + self._set_final_exception( + ConnectionException( + "Failed to set keyspace on all hosts: %s" % (errors,) + ) + ) def _execute_after_prepare(self, host, connection, pool, response): """ @@ -4819,14 +5826,17 @@ def _execute_after_prepare(self, host, connection, pool, response): if response.kind == RESULT_KIND_PREPARED: if self.prepared_statement: if self.prepared_statement.query_id != response.query_id: - self._set_final_exception(DriverException( - "ID mismatch while trying to reprepare (expected {expected}, got {got}). " - "This prepared statement won't work anymore. " - "This usually happens when you run a 'USE...' " - "query after the statement was prepared.".format( - expected=hexlify(self.prepared_statement.query_id), got=hexlify(response.query_id) + self._set_final_exception( + DriverException( + "ID mismatch while trying to reprepare (expected {expected}, got {got}). " + "This prepared statement won't work anymore. " + "This usually happens when you run a 'USE...' " + "query after the statement was prepared.".format( + expected=hexlify(self.prepared_statement.query_id), + got=hexlify(response.query_id), + ) ) - )) + ) self.prepared_statement.result_metadata = response.column_metadata new_metadata_id = response.result_metadata_id if new_metadata_id is not None: @@ -4839,24 +5849,33 @@ def _execute_after_prepare(self, host, connection, pool, response): # this host errored out, move on to the next self.send_request() else: - self._set_final_exception(ConnectionException( - "Got unexpected response when preparing statement " - "on host %s: %s" % (host, response))) + self._set_final_exception( + ConnectionException( + "Got unexpected response when preparing statement " + "on host %s: %s" % (host, response) + ) + ) elif isinstance(response, ErrorMessage): - if hasattr(response, 'to_exception'): + if hasattr(response, "to_exception"): self._set_final_exception(response.to_exception()) else: self._set_final_exception(response) elif isinstance(response, ConnectionException): - log.debug("Connection error when preparing statement on host %s: %s", - host, response) + log.debug( + "Connection error when preparing statement on host %s: %s", + host, + response, + ) # try again on a different host, preparing again if necessary self._errors[host] = response self.send_request() else: - self._set_final_exception(ConnectionException( - "Got unexpected response type when preparing " - "statement on host %s: %s" % (host, response))) + self._set_final_exception( + ConnectionException( + "Got unexpected response type when preparing " + "statement on host %s: %s" % (host, response) + ) + ) def _set_final_result(self, response): self._cancel_timer() @@ -4904,10 +5923,11 @@ def _set_final_exception(self, response): def _handle_retry_decision(self, retry_decision, response, host): def exception_from_response(response): - if hasattr(response, 'to_exception'): + if hasattr(response, "to_exception"): return response.to_exception() else: return response + if len(retry_decision) == 2: retry_type, consistency = retry_decision delay = 0 @@ -4939,7 +5959,9 @@ def _retry(self, reuse_connection, consistency_level, host, delay): self.message.consistency_level = consistency_level # don't retry on the event loop thread - self.session.cluster.scheduler.schedule(delay, self._retry_task, reuse_connection, host) + self.session.cluster.scheduler.schedule( + delay, self._retry_task, reuse_connection, host + ) def _retry_task(self, reuse_connection, host): if self._final_exception: @@ -5006,19 +6028,27 @@ def get_query_trace(self, max_wait=None, query_cl=ConsistencyLevel.LOCAL_ONE): """ if self._final_result is _NOT_SET and self._final_exception is None: raise TraceUnavailable( - "Trace information was not available. The ResponseFuture is not done.") + "Trace information was not available. The ResponseFuture is not done." + ) if self._query_traces: - return self._get_query_trace(len(self._query_traces) - 1, max_wait, query_cl) + return self._get_query_trace( + len(self._query_traces) - 1, max_wait, query_cl + ) - def get_all_query_traces(self, max_wait_per=None, query_cl=ConsistencyLevel.LOCAL_ONE): + def get_all_query_traces( + self, max_wait_per=None, query_cl=ConsistencyLevel.LOCAL_ONE + ): """ Fetches and returns the query traces for all query pages, if tracing was enabled. See note in :meth:`~.get_query_trace` regarding possible exceptions. """ if self._query_traces: - return [self._get_query_trace(i, max_wait_per, query_cl) for i in range(len(self._query_traces))] + return [ + self._get_query_trace(i, max_wait_per, query_cl) + for i in range(len(self._query_traces)) + ] return [] def _get_query_trace(self, i, max_wait, query_cl): @@ -5094,9 +6124,15 @@ def add_errback(self, fn, *args, **kwargs): fn(self._final_exception, *args, **kwargs) return self - def add_callbacks(self, callback, errback, - callback_args=(), callback_kwargs=None, - errback_args=(), errback_kwargs=None): + def add_callbacks( + self, + callback, + errback, + callback_args=(), + callback_kwargs=None, + errback_args=(), + errback_kwargs=None, + ): """ A convenient combination of :meth:`.add_callback()` and :meth:`.add_errback()`. @@ -5128,9 +6164,20 @@ def clear_callbacks(self): self._errbacks = [] def __str__(self): - result = "(no result yet)" if self._final_result is _NOT_SET else self._final_result - return "" \ - % (self.query, self._req_id, result, self._final_exception, self.coordinator_host) + result = ( + "(no result yet)" if self._final_result is _NOT_SET else self._final_result + ) + return ( + "" + % ( + self.query, + self._req_id, + result, + self._final_exception, + self.coordinator_host, + ) + ) + __repr__ = __str__ @@ -5142,6 +6189,7 @@ class QueryExhausted(Exception): .. versionadded:: 2.0.0 """ + pass @@ -5252,7 +6300,9 @@ def fetch_next_page(self): if self.response_future.has_more_pages: self.response_future.start_fetching_next_page() result = self.response_future.result() - self._current_rows = result._current_rows # ResultSet has already _set_current_rows to the appropriate form + self._current_rows = ( + result._current_rows + ) # ResultSet has already _set_current_rows to the appropriate form else: self._current_rows = [] @@ -5261,7 +6311,9 @@ def _set_current_rows(self, result): self._current_rows = [result] if result else [] return try: - iter(result) # can't check directly for generator types because cython generators are different + iter( + result + ) # can't check directly for generator types because cython generators are different self._current_rows = result except TypeError: self._current_rows = [result] if result else [] @@ -5274,9 +6326,14 @@ def _enter_list_mode(self, operator): if self._list_mode: return if self._page_iter: - raise RuntimeError("Cannot use %s when results have been iterated." % operator) + raise RuntimeError( + "Cannot use %s when results have been iterated." % operator + ) if self.response_future.has_more_pages: - log.warning("Using %s on paged results causes entire result set to be materialized.", operator) + log.warning( + "Using %s on paged results causes entire result set to be materialized.", + operator, + ) self._fetch_all() # done regardless of paging status in case the row factory produces a generator self._list_mode = True @@ -5286,8 +6343,11 @@ def __eq__(self, other): def __getitem__(self, i): if i == 0: - warn("ResultSet indexing support will be removed in 4.0. Consider using " - "ResultSet.one() to get a single row.", DeprecationWarning) + warn( + "ResultSet indexing support will be removed in 4.0. Consider using " + "ResultSet.one() to get a single row.", + DeprecationWarning, + ) self._enter_list_mode("index operator") return self._current_rows[i] @@ -5314,9 +6374,11 @@ def cancel_continuous_paging(self): try: self.response_future._continuous_paging_session.cancel() except AttributeError: - raise DriverException("Attempted to cancel paging with no active session. This is only for requests with ContinuousdPagingOptions.") + raise DriverException( + "Attempted to cancel paging with no active session. This is only for requests with ContinuousdPagingOptions." + ) - batch_regex = re.compile(r'^\s*BEGIN\s+[a-zA-Z]*\s*BATCH') + batch_regex = re.compile(r"^\s*BEGIN\s+[a-zA-Z]*\s*BATCH") @property def was_applied(self): @@ -5329,22 +6391,36 @@ def was_applied(self): Only valid when one of the of the internal row factories is in use. """ - if self.response_future.row_factory not in (named_tuple_factory, dict_factory, tuple_factory): - raise RuntimeError("Cannot determine LWT result with row factory %s" % (self.response_future.row_factory,)) + if self.response_future.row_factory not in ( + named_tuple_factory, + dict_factory, + tuple_factory, + ): + raise RuntimeError( + "Cannot determine LWT result with row factory %s" + % (self.response_future.row_factory,) + ) - is_batch_statement = isinstance(self.response_future.query, BatchStatement) \ - or (isinstance(self.response_future.query, SimpleStatement) and self.batch_regex.match(self.response_future.query.query_string)) - if is_batch_statement and (not self.column_names or self.column_names[0] != "[applied]"): + is_batch_statement = isinstance(self.response_future.query, BatchStatement) or ( + isinstance(self.response_future.query, SimpleStatement) + and self.batch_regex.match(self.response_future.query.query_string) + ) + if is_batch_statement and ( + not self.column_names or self.column_names[0] != "[applied]" + ): raise RuntimeError("No LWT were present in the BatchStatement") if not is_batch_statement and len(self.current_rows) != 1: - raise RuntimeError("LWT result should have exactly one row. This has %d." % (len(self.current_rows))) + raise RuntimeError( + "LWT result should have exactly one row. This has %d." + % (len(self.current_rows)) + ) row = self.current_rows[0] if isinstance(row, tuple): return row[0] else: - return row['[applied]'] + return row["[applied]"] @property def paging_state(self): From 11db11c56570c67d5374220da97d79917b89b301 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Wed, 4 Mar 2026 19:02:09 +0200 Subject: [PATCH 05/18] refactor: rename __unicode__ to __str__ in cqlengine/statements.py Python 3 uses __str__ for string representation; __unicode__ was the Python 2 equivalent for unicode strings. The UnicodeMixin base class currently bridges the two by wiring __str__ to call __unicode__(), but this indirection is unnecessary on Python 3. Rename all 18 __unicode__ method definitions in statements.py directly to __str__, and update the one direct __unicode__() call in __repr__ to __str__(). The __str__ definitions on the subclasses now take precedence over the inherited UnicodeMixin.__str__ lambda, so behavior is unchanged. --- cassandra/cqlengine/statements.py | 413 +++++++++++++++++------------- 1 file changed, 239 insertions(+), 174 deletions(-) diff --git a/cassandra/cqlengine/statements.py b/cassandra/cqlengine/statements.py index 4782fdccd8..b1e4e43329 100644 --- a/cassandra/cqlengine/statements.py +++ b/cassandra/cqlengine/statements.py @@ -19,7 +19,12 @@ from cassandra.cqlengine import columns from cassandra.cqlengine import UnicodeMixin from cassandra.cqlengine.functions import QueryValue -from cassandra.cqlengine.operators import BaseWhereOperator, InOperator, EqualsOperator, IsNotNullOperator +from cassandra.cqlengine.operators import ( + BaseWhereOperator, + InOperator, + EqualsOperator, + IsNotNullOperator, +) class StatementException(Exception): @@ -27,18 +32,24 @@ class StatementException(Exception): class ValueQuoter(UnicodeMixin): - def __init__(self, value): self.value = value - def __unicode__(self): + def __str__(self): from cassandra.encoder import cql_quote + if isinstance(self.value, (list, tuple)): - return '[' + ', '.join([cql_quote(v) for v in self.value]) + ']' + return "[" + ", ".join([cql_quote(v) for v in self.value]) + "]" elif isinstance(self.value, dict): - return '{' + ', '.join([cql_quote(k) + ':' + cql_quote(v) for k, v in self.value.items()]) + '}' + return ( + "{" + + ", ".join( + [cql_quote(k) + ":" + cql_quote(v) for k, v in self.value.items()] + ) + + "}" + ) elif isinstance(self.value, set): - return '{' + ', '.join([cql_quote(v) for v in self.value]) + '}' + return "{" + ", ".join([cql_quote(v) for v in self.value]) + "}" return cql_quote(self.value) def __eq__(self, other): @@ -48,20 +59,19 @@ def __eq__(self, other): class InQuoter(ValueQuoter): - - def __unicode__(self): + def __str__(self): from cassandra.encoder import cql_quote - return '(' + ', '.join([cql_quote(v) for v in self.value]) + ')' + return "(" + ", ".join([cql_quote(v) for v in self.value]) + ")" -class BaseClause(UnicodeMixin): +class BaseClause(UnicodeMixin): def __init__(self, field, value): self.field = field self.value = value self.context_id = None - def __unicode__(self): + def __str__(self): raise NotImplementedError def __hash__(self): @@ -76,21 +86,21 @@ def __ne__(self, other): return not self.__eq__(other) def get_context_size(self): - """ returns the number of entries this clause will add to the query context """ + """returns the number of entries this clause will add to the query context""" return 1 def set_context_id(self, i): - """ sets the value placeholder that will be used in the query """ + """sets the value placeholder that will be used in the query""" self.context_id = i def update_context(self, ctx): - """ updates the query context with this clauses values """ + """updates the query context with this clauses values""" assert isinstance(ctx, dict) ctx[str(self.context_id)] = self.value class WhereClause(BaseClause): - """ a single where statement used in queries """ + """a single where statement used in queries""" def __init__(self, field, operator, value, quote_field=True): """ @@ -103,16 +113,20 @@ def __init__(self, field, operator, value, quote_field=True): """ if not isinstance(operator, BaseWhereOperator): raise StatementException( - "operator must be of type {0}, got {1}".format(BaseWhereOperator, type(operator)) + "operator must be of type {0}, got {1}".format( + BaseWhereOperator, type(operator) + ) ) super(WhereClause, self).__init__(field, value) self.operator = operator - self.query_value = self.value if isinstance(self.value, QueryValue) else QueryValue(self.value) + self.query_value = ( + self.value if isinstance(self.value, QueryValue) else QueryValue(self.value) + ) self.quote_field = quote_field - def __unicode__(self): - field = ('"{0}"' if self.quote_field else '{0}').format(self.field) - return u'{0} {1} {2}'.format(field, self.operator, str(self.query_value)) + def __str__(self): + field = ('"{0}"' if self.quote_field else "{0}").format(self.field) + return "{0} {1} {2}".format(field, self.operator, str(self.query_value)) def __hash__(self): return super(WhereClause, self).__hash__() ^ hash(self.operator) @@ -138,11 +152,11 @@ def update_context(self, ctx): class IsNotNullClause(WhereClause): def __init__(self, field): - super(IsNotNullClause, self).__init__(field, IsNotNullOperator(), '') + super(IsNotNullClause, self).__init__(field, IsNotNullOperator(), "") - def __unicode__(self): - field = ('"{0}"' if self.quote_field else '{0}').format(self.field) - return u'{0} {1}'.format(field, self.operator) + def __str__(self): + field = ('"{0}"' if self.quote_field else "{0}").format(self.field) + return "{0} {1}".format(field, self.operator) def update_context(self, ctx): pass @@ -150,34 +164,34 @@ def update_context(self, ctx): def get_context_size(self): return 0 + # alias for convenience IsNotNull = IsNotNullClause class AssignmentClause(BaseClause): - """ a single variable st statement """ + """a single variable st statement""" - def __unicode__(self): - return u'"{0}" = %({1})s'.format(self.field, self.context_id) + def __str__(self): + return '"{0}" = %({1})s'.format(self.field, self.context_id) def insert_tuple(self): return self.field, self.context_id class ConditionalClause(BaseClause): - """ A single variable iff statement """ + """A single variable iff statement""" - def __unicode__(self): - return u'"{0}" = %({1})s'.format(self.field, self.context_id) + def __str__(self): + return '"{0}" = %({1})s'.format(self.field, self.context_id) def insert_tuple(self): return self.field, self.context_id class ContainerUpdateTypeMapMeta(type): - def __init__(cls, name, bases, dct): - if not hasattr(cls, 'type_map'): + if not hasattr(cls, "type_map"): cls.type_map = {} else: cls.type_map[cls.col_type] = cls @@ -185,7 +199,6 @@ def __init__(cls, name, bases, dct): class ContainerUpdateClause(AssignmentClause, metaclass=ContainerUpdateTypeMapMeta): - def __init__(self, field, value, operation=None, previous=None): super(ContainerUpdateClause, self).__init__(field, value) self.previous = previous @@ -204,20 +217,22 @@ def update_context(self, ctx): class SetUpdateClause(ContainerUpdateClause): - """ updates a set collection """ + """updates a set collection""" col_type = columns.Set _additions = None _removals = None - def __unicode__(self): + def __str__(self): qs = [] ctx_id = self.context_id - if (self.previous is None and - self._assignments is None and - self._additions is None and - self._removals is None): + if ( + self.previous is None + and self._assignments is None + and self._additions is None + and self._removals is None + ): qs += ['"{0}" = %({1})s'.format(self.field, ctx_id)] if self._assignments is not None: qs += ['"{0}" = %({1})s'.format(self.field, ctx_id)] @@ -228,10 +243,10 @@ def __unicode__(self): if self._removals is not None: qs += ['"{0}" = "{0}" - %({1})s'.format(self.field, ctx_id)] - return ', '.join(qs) + return ", ".join(qs) def _analyze(self): - """ works out the updates to be performed """ + """works out the updates to be performed""" if self.value is None or self.value == self.previous: pass elif self._operation == "add": @@ -249,21 +264,29 @@ def _analyze(self): def get_context_size(self): if not self._analyzed: self._analyze() - if (self.previous is None and - not self._assignments and - self._additions is None and - self._removals is None): + if ( + self.previous is None + and not self._assignments + and self._additions is None + and self._removals is None + ): return 1 - return int(bool(self._assignments)) + int(bool(self._additions)) + int(bool(self._removals)) + return ( + int(bool(self._assignments)) + + int(bool(self._additions)) + + int(bool(self._removals)) + ) def update_context(self, ctx): if not self._analyzed: self._analyze() ctx_id = self.context_id - if (self.previous is None and - self._assignments is None and - self._additions is None and - self._removals is None): + if ( + self.previous is None + and self._assignments is None + and self._additions is None + and self._removals is None + ): ctx[str(ctx_id)] = set() if self._assignments is not None: ctx[str(ctx_id)] = self._assignments @@ -276,14 +299,14 @@ def update_context(self, ctx): class ListUpdateClause(ContainerUpdateClause): - """ updates a list collection """ + """updates a list collection""" col_type = columns.List _append = None _prepend = None - def __unicode__(self): + def __str__(self): if not self._analyzed: self._analyze() qs = [] @@ -299,12 +322,16 @@ def __unicode__(self): if self._append is not None: qs += ['"{0}" = "{0}" + %({1})s'.format(self.field, ctx_id)] - return ', '.join(qs) + return ", ".join(qs) def get_context_size(self): if not self._analyzed: self._analyze() - return int(self._assignments is not None) + int(bool(self._append)) + int(bool(self._prepend)) + return ( + int(self._assignments is not None) + + int(bool(self._append)) + + int(bool(self._prepend)) + ) def update_context(self, ctx): if not self._analyzed: @@ -320,7 +347,7 @@ def update_context(self, ctx): ctx[str(ctx_id)] = self._append def _analyze(self): - """ works out the updates to be performed """ + """works out the updates to be performed""" if self.value is None or self.value == self.previous: pass @@ -343,7 +370,6 @@ def _analyze(self): # list, do a complete insert self._assignments = self.value else: - # the max start idx we want to compare search_space = len(self.value) - max(0, len(self.previous) - 1) @@ -369,7 +395,7 @@ def _analyze(self): class MapUpdateClause(ContainerUpdateClause): - """ updates a map collection """ + """updates a map collection""" col_type = columns.Map @@ -385,7 +411,12 @@ def _analyze(self): if self.previous is None: self._updates = sorted([k for k, v in self.value.items()]) else: - self._updates = sorted([k for k, v in self.value.items() if v != self.previous.get(k)]) or None + self._updates = ( + sorted( + [k for k, v in self.value.items() if v != self.previous.get(k)] + ) + or None + ) self._analyzed = True def get_context_size(self): @@ -412,7 +443,7 @@ def is_assignment(self): self._analyze() return self.previous is None and not self._updates and not self._removals - def __unicode__(self): + def __str__(self): qs = [] ctx_id = self.context_id @@ -423,14 +454,15 @@ def __unicode__(self): ctx_id += 1 else: for _ in self._updates or []: - qs += ['"{0}"[%({1})s] = %({2})s'.format(self.field, ctx_id, ctx_id + 1)] + qs += [ + '"{0}"[%({1})s] = %({2})s'.format(self.field, ctx_id, ctx_id + 1) + ] ctx_id += 2 - return ', '.join(qs) + return ", ".join(qs) class CounterUpdateClause(AssignmentClause): - col_type = columns.Counter def __init__(self, field, value, previous=None): @@ -443,9 +475,9 @@ def get_context_size(self): def update_context(self, ctx): ctx[str(self.context_id)] = abs(self.value - self.previous) - def __unicode__(self): + def __str__(self): delta = self.value - self.previous - sign = '-' if delta < 0 else '+' + sign = "-" if delta < 0 else "+" return '"{0}" = "{0}" {1} %({2})s'.format(self.field, sign, self.context_id) @@ -454,12 +486,12 @@ class BaseDeleteClause(BaseClause): class FieldDeleteClause(BaseDeleteClause): - """ deletes a field from a row """ + """deletes a field from a row""" def __init__(self, field): super(FieldDeleteClause, self).__init__(field, None) - def __unicode__(self): + def __str__(self): return '"{0}"'.format(self.field) def update_context(self, ctx): @@ -470,7 +502,7 @@ def get_context_size(self): class MapDeleteClause(BaseDeleteClause): - """ removes keys from a map """ + """removes keys from a map""" def __init__(self, field, value, previous=None): super(MapDeleteClause, self).__init__(field, value) @@ -494,16 +526,23 @@ def get_context_size(self): self._analyze() return len(self._removals) - def __unicode__(self): + def __str__(self): if not self._analyzed: self._analyze() - return ', '.join(['"{0}"[%({1})s]'.format(self.field, self.context_id + i) for i in range(len(self._removals))]) + return ", ".join( + [ + '"{0}"[%({1})s]'.format(self.field, self.context_id + i) + for i in range(len(self._removals)) + ] + ) class BaseCQLStatement(UnicodeMixin): - """ The base cql statement class """ + """The base cql statement class""" - def __init__(self, table, timestamp=None, where=None, fetch_size=None, conditionals=None): + def __init__( + self, table, timestamp=None, where=None, fetch_size=None, conditionals=None + ): super(BaseCQLStatement, self).__init__() self.table = table self.context_id = 0 @@ -525,7 +564,11 @@ def _update_part_key_values(self, field_index_map, clauses, parts): def partition_key_values(self, field_index_map): parts = [None] * len(field_index_map) - self._update_part_key_values(field_index_map, (w for w in self.where_clauses if w.operator.__class__ == EqualsOperator), parts) + self._update_part_key_values( + field_index_map, + (w for w in self.where_clauses if w.operator.__class__ == EqualsOperator), + parts, + ) return parts def add_where(self, column, operator, value, quote_field=True): @@ -560,7 +603,7 @@ def add_conditional_clause(self, clause): self.conditionals.append(clause) def _get_conditionals(self): - return 'IF {0}'.format(' AND '.join([str(c) for c in self.conditionals])) + return "IF {0}".format(" AND ".join([str(c) for c in self.conditionals])) def get_context_size(self): return len(self.get_context()) @@ -589,42 +632,39 @@ def timestamp_normalized(self): else: tmp = self.timestamp - return int(time.mktime(tmp.timetuple()) * 1e+6 + tmp.microsecond) + return int(time.mktime(tmp.timetuple()) * 1e6 + tmp.microsecond) - def __unicode__(self): + def __str__(self): raise NotImplementedError def __repr__(self): - return self.__unicode__() + return self.__str__() @property def _where(self): - return 'WHERE {0}'.format(' AND '.join([str(c) for c in self.where_clauses])) + return "WHERE {0}".format(" AND ".join([str(c) for c in self.where_clauses])) class SelectStatement(BaseCQLStatement): - """ a cql select statement """ - - def __init__(self, - table, - fields=None, - count=False, - where=None, - order_by=None, - limit=None, - allow_filtering=False, - distinct_fields=None, - fetch_size=None): - + """a cql select statement""" + + def __init__( + self, + table, + fields=None, + count=False, + where=None, + order_by=None, + limit=None, + allow_filtering=False, + distinct_fields=None, + fetch_size=None, + ): """ :param where :type where list of cqlengine.statements.WhereClause """ - super(SelectStatement, self).__init__( - table, - where=where, - fetch_size=fetch_size - ) + super(SelectStatement, self).__init__(table, where=where, fetch_size=fetch_size) self.fields = [fields] if isinstance(fields, str) else (fields or []) self.distinct_fields = distinct_fields @@ -633,48 +673,60 @@ def __init__(self, self.limit = limit self.allow_filtering = allow_filtering - def __unicode__(self): - qs = ['SELECT'] + def __str__(self): + qs = ["SELECT"] if self.distinct_fields: if self.count: - qs += ['DISTINCT COUNT({0})'.format(', '.join(['"{0}"'.format(f) for f in self.distinct_fields]))] + qs += [ + "DISTINCT COUNT({0})".format( + ", ".join(['"{0}"'.format(f) for f in self.distinct_fields]) + ) + ] else: - qs += ['DISTINCT {0}'.format(', '.join(['"{0}"'.format(f) for f in self.distinct_fields]))] + qs += [ + "DISTINCT {0}".format( + ", ".join(['"{0}"'.format(f) for f in self.distinct_fields]) + ) + ] elif self.count: - qs += ['COUNT(*)'] + qs += ["COUNT(*)"] else: - qs += [', '.join(['"{0}"'.format(f) for f in self.fields]) if self.fields else '*'] - qs += ['FROM', self.table] + qs += [ + ", ".join(['"{0}"'.format(f) for f in self.fields]) + if self.fields + else "*" + ] + qs += ["FROM", self.table] if self.where_clauses: qs += [self._where] if self.order_by and not self.count: - qs += ['ORDER BY {0}'.format(', '.join(str(o) for o in self.order_by))] + qs += ["ORDER BY {0}".format(", ".join(str(o) for o in self.order_by))] if self.limit: - qs += ['LIMIT {0}'.format(self.limit)] + qs += ["LIMIT {0}".format(self.limit)] if self.allow_filtering: - qs += ['ALLOW FILTERING'] + qs += ["ALLOW FILTERING"] - return ' '.join(qs) + return " ".join(qs) class AssignmentStatement(BaseCQLStatement): - """ value assignment statements """ - - def __init__(self, - table, - assignments=None, - where=None, - ttl=None, - timestamp=None, - conditionals=None): + """value assignment statements""" + + def __init__( + self, + table, + assignments=None, + where=None, + ttl=None, + timestamp=None, + conditionals=None, + ): super(AssignmentStatement, self).__init__( - table, - where=where, - conditionals=conditionals + table, where=where, conditionals=conditionals ) self.ttl = ttl self.timestamp = timestamp @@ -717,33 +769,33 @@ def get_context(self): class InsertStatement(AssignmentStatement): - """ an cql insert statement """ - - def __init__(self, - table, - assignments=None, - where=None, - ttl=None, - timestamp=None, - if_not_exists=False): - super(InsertStatement, self).__init__(table, - assignments=assignments, - where=where, - ttl=ttl, - timestamp=timestamp) + """an cql insert statement""" + + def __init__( + self, + table, + assignments=None, + where=None, + ttl=None, + timestamp=None, + if_not_exists=False, + ): + super(InsertStatement, self).__init__( + table, assignments=assignments, where=where, ttl=ttl, timestamp=timestamp + ) self.if_not_exists = if_not_exists - def __unicode__(self): - qs = ['INSERT INTO {0}'.format(self.table)] + def __str__(self): + qs = ["INSERT INTO {0}".format(self.table)] # get column names and context placeholders fields = [a.insert_tuple() for a in self.assignments] columns, values = zip(*fields) - qs += ["({0})".format(', '.join(['"{0}"'.format(c) for c in columns]))] - qs += ['VALUES'] - qs += ["({0})".format(', '.join(['%({0})s'.format(v) for v in values]))] + qs += ["({0})".format(", ".join(['"{0}"'.format(c) for c in columns]))] + qs += ["VALUES"] + qs += ["({0})".format(", ".join(["%({0})s".format(v) for v in values]))] if self.if_not_exists: qs += ["IF NOT EXISTS"] @@ -757,31 +809,35 @@ def __unicode__(self): if using_options: qs += ["USING {}".format(" AND ".join(using_options))] - return ' '.join(qs) + return " ".join(qs) class UpdateStatement(AssignmentStatement): - """ an cql update select statement """ - - def __init__(self, - table, - assignments=None, - where=None, - ttl=None, - timestamp=None, - conditionals=None, - if_exists=False): - super(UpdateStatement, self). __init__(table, - assignments=assignments, - where=where, - ttl=ttl, - timestamp=timestamp, - conditionals=conditionals) + """an cql update select statement""" + + def __init__( + self, + table, + assignments=None, + where=None, + ttl=None, + timestamp=None, + conditionals=None, + if_exists=False, + ): + super(UpdateStatement, self).__init__( + table, + assignments=assignments, + where=where, + ttl=ttl, + timestamp=timestamp, + conditionals=conditionals, + ) self.if_exists = if_exists - def __unicode__(self): - qs = ['UPDATE', self.table] + def __str__(self): + qs = ["UPDATE", self.table] using_options = [] @@ -794,8 +850,8 @@ def __unicode__(self): if using_options: qs += ["USING {0}".format(" AND ".join(using_options))] - qs += ['SET'] - qs += [', '.join([str(c) for c in self.assignments])] + qs += ["SET"] + qs += [", ".join([str(c) for c in self.assignments])] if self.where_clauses: qs += [self._where] @@ -806,7 +862,7 @@ def __unicode__(self): if self.if_exists: qs += ["IF EXISTS"] - return ' '.join(qs) + return " ".join(qs) def get_context(self): ctx = super(UpdateStatement, self).get_context() @@ -822,13 +878,15 @@ def update_context_id(self, i): def add_update(self, column, value, operation=None, previous=None): # For remove all values are None, no need to convert them - if operation != 'remove': + if operation != "remove": value = column.to_database(value) col_type = type(column) container_update_type = ContainerUpdateClause.type_map.get(col_type) if container_update_type: previous = column.to_database(previous) - clause = container_update_type(column.db_field_name, value, operation, previous) + clause = container_update_type( + column.db_field_name, value, operation, previous + ) elif col_type == columns.Counter: clause = CounterUpdateClause(column.db_field_name, value, previous) else: @@ -838,14 +896,19 @@ def add_update(self, column, value, operation=None, previous=None): class DeleteStatement(BaseCQLStatement): - """ a cql delete statement """ - - def __init__(self, table, fields=None, where=None, timestamp=None, conditionals=None, if_exists=False): + """a cql delete statement""" + + def __init__( + self, + table, + fields=None, + where=None, + timestamp=None, + conditionals=None, + if_exists=False, + ): super(DeleteStatement, self).__init__( - table, - where=where, - timestamp=timestamp, - conditionals=conditionals + table, where=where, timestamp=timestamp, conditionals=conditionals ) self.fields = [] if isinstance(fields, str): @@ -876,16 +939,18 @@ def add_field(self, field): if isinstance(field, str): field = FieldDeleteClause(field) if not isinstance(field, BaseClause): - raise StatementException("only instances of AssignmentClause can be added to statements") + raise StatementException( + "only instances of AssignmentClause can be added to statements" + ) field.set_context_id(self.context_counter) self.context_counter += field.get_context_size() self.fields.append(field) - def __unicode__(self): - qs = ['DELETE'] + def __str__(self): + qs = ["DELETE"] if self.fields: - qs += [', '.join(['{0}'.format(f) for f in self.fields])] - qs += ['FROM', self.table] + qs += [", ".join(["{0}".format(f) for f in self.fields])] + qs += ["FROM", self.table] delete_option = [] @@ -904,4 +969,4 @@ def __unicode__(self): if self.if_exists: qs += ["IF EXISTS"] - return ' '.join(qs) + return " ".join(qs) From d7770b8dbeb000b3abc549f8aae07c53f27d9611 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Wed, 4 Mar 2026 19:02:54 +0200 Subject: [PATCH 06/18] refactor: rename __unicode__ to __str__ in cqlengine/query.py Rename the 2 __unicode__ methods in AbstractQueryableColumn and ModelQuerySet to __str__. Remove the redundant __str__ wrapper in ModelQuerySet that existed solely to bridge __str__ -> __unicode__ for Python 2 compatibility. In Python 3, __str__ is the canonical string representation method; __unicode__ served that role in Python 2. The indirection through UnicodeMixin is no longer needed for these classes. --- cassandra/cqlengine/query.py | 584 +++++++++++++++++++++++++---------- 1 file changed, 414 insertions(+), 170 deletions(-) diff --git a/cassandra/cqlengine/query.py b/cassandra/cqlengine/query.py index afc7ceeef6..0d492bc016 100644 --- a/cassandra/cqlengine/query.py +++ b/cassandra/cqlengine/query.py @@ -19,15 +19,34 @@ from warnings import warn from cassandra.query import SimpleStatement, BatchType as CBatchType, BatchStatement -from cassandra.cqlengine import columns, CQLEngineException, ValidationError, UnicodeMixin +from cassandra.cqlengine import ( + columns, + CQLEngineException, + ValidationError, + UnicodeMixin, +) from cassandra.cqlengine import connection as conn from cassandra.cqlengine.functions import Token, BaseQueryFunction, QueryValue -from cassandra.cqlengine.operators import (InOperator, EqualsOperator, GreaterThanOperator, - GreaterThanOrEqualOperator, LessThanOperator, - LessThanOrEqualOperator, ContainsOperator, BaseWhereOperator) -from cassandra.cqlengine.statements import (WhereClause, SelectStatement, DeleteStatement, - UpdateStatement, InsertStatement, - BaseCQLStatement, MapDeleteClause, ConditionalClause) +from cassandra.cqlengine.operators import ( + InOperator, + EqualsOperator, + GreaterThanOperator, + GreaterThanOrEqualOperator, + LessThanOperator, + LessThanOrEqualOperator, + ContainsOperator, + BaseWhereOperator, +) +from cassandra.cqlengine.statements import ( + WhereClause, + SelectStatement, + DeleteStatement, + UpdateStatement, + InsertStatement, + BaseCQLStatement, + MapDeleteClause, + ConditionalClause, +) class QueryException(CQLEngineException): @@ -51,6 +70,7 @@ class LWTException(CQLEngineException): :param existing: The current state of the data which prevented the write. """ + def __init__(self, existing): super(LWTException, self).__init__("LWT Query was not applied") self.existing = existing @@ -87,7 +107,7 @@ class AbstractQueryableColumn(UnicodeMixin): def _get_column(self): raise NotImplementedError - def __unicode__(self): + def __str__(self): raise NotImplementedError def _to_database(self, val): @@ -110,7 +130,6 @@ def contains_(self, item): """ return WhereClause(str(self), ContainsOperator(), item) - def __eq__(self, other): return WhereClause(str(self), EqualsOperator(), self._to_database(other)) @@ -118,18 +137,22 @@ def __gt__(self, other): return WhereClause(str(self), GreaterThanOperator(), self._to_database(other)) def __ge__(self, other): - return WhereClause(str(self), GreaterThanOrEqualOperator(), self._to_database(other)) + return WhereClause( + str(self), GreaterThanOrEqualOperator(), self._to_database(other) + ) def __lt__(self, other): return WhereClause(str(self), LessThanOperator(), self._to_database(other)) def __le__(self, other): - return WhereClause(str(self), LessThanOrEqualOperator(), self._to_database(other)) + return WhereClause( + str(self), LessThanOrEqualOperator(), self._to_database(other) + ) class BatchType(object): - Unlogged = 'UNLOGGED' - Counter = 'COUNTER' + Unlogged = "UNLOGGED" + Counter = "COUNTER" class BatchQuery(object): @@ -140,6 +163,7 @@ class BatchQuery(object): See :doc:`/cqlengine/batches` for more details. """ + warn_multiple_exec = True _consistency = None @@ -147,9 +171,15 @@ class BatchQuery(object): _connection = None _connection_explicit = False - - def __init__(self, batch_type=None, timestamp=None, consistency=None, execute_on_exception=False, - timeout=conn.NOT_SET, connection=None): + def __init__( + self, + batch_type=None, + timestamp=None, + consistency=None, + execute_on_exception=False, + timeout=conn.NOT_SET, + connection=None, + ): """ :param batch_type: (optional) One of batch type values available through BatchType enum :type batch_type: BatchType, str or None @@ -171,7 +201,7 @@ def __init__(self, batch_type=None, timestamp=None, consistency=None, execute_on self.queries = [] self.batch_type = batch_type if timestamp is not None and not isinstance(timestamp, (datetime, timedelta)): - raise CQLEngineException('timestamp object must be an instance of datetime') + raise CQLEngineException("timestamp object must be an instance of datetime") self.timestamp = timestamp self._consistency = consistency self._execute_on_exception = execute_on_exception @@ -185,7 +215,9 @@ def __init__(self, batch_type=None, timestamp=None, consistency=None, execute_on def add_query(self, query): if not isinstance(query, BaseCQLStatement): - raise CQLEngineException('only BaseCQLStatements can be added to a batch query') + raise CQLEngineException( + "only BaseCQLStatements can be added to a batch query" + ) self.queries.append(query) def consistency(self, consistency): @@ -209,7 +241,11 @@ def add_callback(self, fn, *args, **kwargs): :param kwargs: Named arguments to be passed to the callback at the time of execution """ if not callable(fn): - raise ValueError("Value for argument 'fn' is {0} and is not a callable object.".format(type(fn))) + raise ValueError( + "Value for argument 'fn' is {0} and is not a callable object.".format( + type(fn) + ) + ) self._callbacks.append((fn, args, kwargs)) def execute(self): @@ -227,20 +263,19 @@ def execute(self): return batch_type = None if self.batch_type is CBatchType.LOGGED else self.batch_type - opener = 'BEGIN ' + (str(batch_type) + ' ' if batch_type else '') + ' BATCH' + opener = "BEGIN " + (str(batch_type) + " " if batch_type else "") + " BATCH" if self.timestamp: - if isinstance(self.timestamp, int): ts = self.timestamp elif isinstance(self.timestamp, (datetime, timedelta)): ts = self.timestamp if isinstance(self.timestamp, timedelta): ts += datetime.now() # Apply timedelta - ts = int(time.mktime(ts.timetuple()) * 1e+6 + ts.microsecond) + ts = int(time.mktime(ts.timetuple()) * 1e6 + ts.microsecond) else: raise ValueError("Batch expects a long, a timedelta, or a datetime") - opener += ' USING TIMESTAMP {0}'.format(ts) + opener += " USING TIMESTAMP {0}".format(ts) query_list = [opener] parameters = {} @@ -249,12 +284,18 @@ def execute(self): query.update_context_id(ctx_counter) ctx = query.get_context() ctx_counter += len(ctx) - query_list.append(' ' + str(query)) + query_list.append(" " + str(query)) parameters.update(ctx) - query_list.append('APPLY BATCH;') + query_list.append("APPLY BATCH;") - tmp = conn.execute('\n'.join(query_list), parameters, self._consistency, self._timeout, connection=self._connection) + tmp = conn.execute( + "\n".join(query_list), + parameters, + self._consistency, + self._timeout, + connection=self._connection, + ) check_applied(tmp) self.queries = [] @@ -305,12 +346,13 @@ def __init__(self, *args, **kwargs): if len(args) < 1: raise ValueError("No model provided.") - keyspace = kwargs.pop('keyspace', None) - connection = kwargs.pop('connection', None) + keyspace = kwargs.pop("keyspace", None) + connection = kwargs.pop("connection", None) if kwargs: - raise ValueError("Unknown keyword argument(s): {0}".format( - ','.join(kwargs.keys()))) + raise ValueError( + "Unknown keyword argument(s): {0}".format(",".join(kwargs.keys())) + ) for model in args: try: @@ -337,7 +379,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): class AbstractQuerySet(object): - def __init__(self, model): super(AbstractQuerySet, self).__init__() self.model = model @@ -382,7 +423,7 @@ def __init__(self, model): self._count = None self._batch = None - self._ttl = None + self._ttl = None self._consistency = None self._timestamp = None self._if_not_exists = False @@ -400,16 +441,19 @@ def _execute(self, statement): return self._batch.add_query(statement) else: connection = self._connection or self.model._get_connection() - result = _execute_statement(self.model, statement, self._consistency, self._timeout, connection=connection) + result = _execute_statement( + self.model, + statement, + self._consistency, + self._timeout, + connection=connection, + ) if self._if_not_exists or self._if_exists or self._conditional: check_applied(result) return result - def __unicode__(self): - return str(self._select_query()) - def __str__(self): - return str(self.__unicode__()) + return str(self._select_query()) def __call__(self, *args, **kwargs): return self.filter(*args, **kwargs) @@ -417,15 +461,22 @@ def __call__(self, *args, **kwargs): def __deepcopy__(self, memo): clone = self.__class__(self.model) for k, v in self.__dict__.items(): - if k in ['_con', '_cur', '_result_cache', '_result_idx', '_result_generator', '_construct_result']: # don't clone these, which are per-request-execution + if k in [ + "_con", + "_cur", + "_result_cache", + "_result_idx", + "_result_generator", + "_construct_result", + ]: # don't clone these, which are per-request-execution clone.__dict__[k] = None - elif k == '_batch': + elif k == "_batch": # we need to keep the same batch instance across # all queryset clones, otherwise the batched queries # fly off into other batch instances which are never # executed, thx @dokai clone.__dict__[k] = self._batch - elif k == '_timeout': + elif k == "_timeout": clone.__dict__[k] = self._timeout else: clone.__dict__[k] = copy.deepcopy(v, memo) @@ -439,11 +490,11 @@ def __len__(self): # ----query generation / execution---- def _select_fields(self): - """ returns the fields to select """ + """returns the fields to select""" return [] def _validate_select_where(self): - """ put select query validation here """ + """put select query validation here""" def _select_query(self): """ @@ -459,18 +510,22 @@ def _select_query(self): limit=self._limit, allow_filtering=self._allow_filtering, distinct_fields=self._distinct_fields, - fetch_size=self._fetch_size + fetch_size=self._fetch_size, ) # ----Reads------ def _execute_query(self): if self._batch: - raise CQLEngineException("Only inserts, updates, and deletes are available in batch mode") + raise CQLEngineException( + "Only inserts, updates, and deletes are available in batch mode" + ) if self._result_cache is None: self._result_generator = (i for i in self._execute(self._select_query())) self._result_cache = [] - self._construct_result = self._maybe_inject_deferred(self._get_result_constructor()) + self._construct_result = self._maybe_inject_deferred( + self._get_result_constructor() + ) # "DISTINCT COUNT()" is not supported in C* < 2.2, so we need to materialize all results to get # len() and count() working with DISTINCT queries @@ -505,7 +560,9 @@ def _fill_result_cache_to_idx(self, idx): self._result_idx += 1 while True: try: - self._result_cache[self._result_idx] = self._construct_result(self._result_cache[self._result_idx]) + self._result_cache[self._result_idx] = self._construct_result( + self._result_cache[self._result_idx] + ) break except IndexError: self._result_cache.append(next(self._result_generator)) @@ -535,8 +592,10 @@ def __getitem__(self, s): start = s.start if s.start else 0 if start < 0 or (s.stop is not None and s.stop < 0): - warn("ModelQuerySet slicing with negative indices support will be removed in 4.0.", - DeprecationWarning) + warn( + "ModelQuerySet slicing with negative indices support will be removed in 4.0.", + DeprecationWarning, + ) # calculate the amount of results that need to be loaded end = s.stop @@ -548,16 +607,18 @@ def __getitem__(self, s): except StopIteration: pass - return self._result_cache[start:s.stop:s.step] + return self._result_cache[start : s.stop : s.step] else: try: s = int(s) except (ValueError, TypeError): - raise TypeError('QuerySet indices must be integers') + raise TypeError("QuerySet indices must be integers") if s < 0: - warn("ModelQuerySet indexing with negative indices support will be removed in 4.0.", - DeprecationWarning) + warn( + "ModelQuerySet indexing with negative indices support will be removed in 4.0.", + DeprecationWarning, + ) # Using negative indexing is costly since we have to execute a count() if s < 0: @@ -583,8 +644,11 @@ def _construct_with_deferred(f, deferred, row): return f(row) def _maybe_inject_deferred(self, constructor): - return partial(self._construct_with_deferred, constructor, self._deferred_values)\ - if self._deferred_values else constructor + return ( + partial(self._construct_with_deferred, constructor, self._deferred_values) + if self._deferred_values + else constructor + ) def batch(self, batch_obj): """ @@ -593,10 +657,12 @@ def batch(self, batch_obj): Note: running a select query with a batch object will raise an exception """ if self._connection: - raise CQLEngineException("Cannot specify the connection on model in batch mode.") + raise CQLEngineException( + "Cannot specify the connection on model in batch mode." + ) if batch_obj is not None and not isinstance(batch_obj, BatchQuery): - raise CQLEngineException('batch_obj must be a BatchQuery instance or None') + raise CQLEngineException("batch_obj must be a BatchQuery instance or None") clone = copy.deepcopy(self) clone._batch = batch_obj return clone @@ -637,11 +703,11 @@ def _parse_filter_arg(self, arg): __ :returns: colname, op tuple """ - statement = arg.rsplit('__', 1) + statement = arg.rsplit("__", 1) if len(statement) == 1: return arg, None elif len(statement) == 2: - return (statement[0], statement[1]) if arg != 'pk__token' else (arg, None) + return (statement[0], statement[1]) if arg != "pk__token" else (arg, None) else: raise QueryException("Can't parse '{0}'".format(arg)) @@ -653,7 +719,9 @@ def iff(self, *args, **kwargs): clone = copy.deepcopy(self) for operator in args: if not isinstance(operator, ConditionalClause): - raise QueryException('{0} is not a valid query operator'.format(operator)) + raise QueryException( + "{0} is not a valid query operator".format(operator) + ) clone._conditional.append(operator) for arg, val in kwargs.items(): @@ -664,16 +732,20 @@ def iff(self, *args, **kwargs): try: column = self.model._get_column(col_name) except KeyError: - raise QueryException("Can't resolve column name: '{0}'".format(col_name)) + raise QueryException( + "Can't resolve column name: '{0}'".format(col_name) + ) if isinstance(val, BaseQueryFunction): query_val = val else: query_val = column.to_database(val) - operator_class = BaseWhereOperator.get_operator(col_op or 'EQ') + operator_class = BaseWhereOperator.get_operator(col_op or "EQ") operator = operator_class() - clone._conditional.append(WhereClause(column.db_field_name, operator, query_val)) + clone._conditional.append( + WhereClause(column.db_field_name, operator, query_val) + ) return clone @@ -692,7 +764,9 @@ def filter(self, *args, **kwargs): clone = copy.deepcopy(self) for operator in args: if not isinstance(operator, WhereClause): - raise QueryException('{0} is not a valid query operator'.format(operator)) + raise QueryException( + "{0} is not a valid query operator".format(operator) + ) clone._where.append(operator) for arg, val in kwargs.items(): @@ -703,10 +777,14 @@ def filter(self, *args, **kwargs): try: column = self.model._get_column(col_name) except KeyError: - raise QueryException("Can't resolve column name: '{0}'".format(col_name)) + raise QueryException( + "Can't resolve column name: '{0}'".format(col_name) + ) else: - if col_name != 'pk__token': - raise QueryException("Token() values may only be compared to the 'pk__token' virtual column") + if col_name != "pk__token": + raise QueryException( + "Token() values may only be compared to the 'pk__token' virtual column" + ) column = columns._PartitionKeysToken(self.model) quote_field = False @@ -714,31 +792,40 @@ def filter(self, *args, **kwargs): partition_columns = column.partition_columns if len(partition_columns) != len(val.value): raise QueryException( - 'Token() received {0} arguments but model has {1} partition keys'.format( - len(val.value), len(partition_columns))) + "Token() received {0} arguments but model has {1} partition keys".format( + len(val.value), len(partition_columns) + ) + ) val.set_columns(partition_columns) # get query operator, or use equals if not supplied - operator_class = BaseWhereOperator.get_operator(col_op or 'EQ') + operator_class = BaseWhereOperator.get_operator(col_op or "EQ") operator = operator_class() if isinstance(operator, InOperator): if not isinstance(val, (list, tuple)): - raise QueryException('IN queries must use a list/tuple value') + raise QueryException("IN queries must use a list/tuple value") query_val = [column.to_database(v) for v in val] elif isinstance(val, BaseQueryFunction): query_val = val - elif (isinstance(operator, ContainsOperator) and - isinstance(column, (columns.List, columns.Set, columns.Map))): + elif isinstance(operator, ContainsOperator) and isinstance( + column, (columns.List, columns.Set, columns.Map) + ): # For ContainsOperator and collections, we query using the value, not the container query_val = val else: query_val = column.to_database(val) if not col_op: # only equal values should be deferred clone._defer_fields.add(column.db_field_name) - clone._deferred_values[column.db_field_name] = val # map by db field name for substitution in results + clone._deferred_values[column.db_field_name] = ( + val # map by db field name for substitution in results + ) - clone._where.append(WhereClause(column.db_field_name, operator, query_val, quote_field=quote_field)) + clone._where.append( + WhereClause( + column.db_field_name, operator, query_val, quote_field=quote_field + ) + ) return clone @@ -766,7 +853,7 @@ def get(self, *args, **kwargs): # Check that the resultset only contains one element, avoiding sending a COUNT query try: self[1] - raise self.model.MultipleObjectsReturned('Multiple objects found') + raise self.model.MultipleObjectsReturned("Multiple objects found") except IndexError: pass @@ -778,8 +865,8 @@ def get(self, *args, **kwargs): return obj def _get_ordering_condition(self, colname): - order_type = 'DESC' if colname.startswith('-') else 'ASC' - colname = colname.replace('-', '') + order_type = "DESC" if colname.startswith("-") else "ASC" + colname = colname.replace("-", "") return colname, order_type @@ -821,7 +908,9 @@ class Comment(Model): conditions = [] for colname in colnames: - conditions.append('"{0}" {1}'.format(*self._get_ordering_condition(colname))) + conditions.append( + '"{0}" {1}'.format(*self._get_ordering_condition(colname)) + ) clone = copy.deepcopy(self) clone._order.extend(conditions) @@ -834,7 +923,9 @@ def count(self): *Note: This function executes a SELECT COUNT() and has a performance cost on large datasets* """ if self._batch: - raise CQLEngineException("Only inserts, updates, and deletes are available in batch mode") + raise CQLEngineException( + "Only inserts, updates, and deletes are available in batch mode" + ) if self._count is None: query = self._select_query() @@ -876,7 +967,9 @@ class Automobile(Model): if distinct_fields: clone._distinct_fields = distinct_fields else: - clone._distinct_fields = [x.column_name for x in self.model._partition_keys.values()] + clone._distinct_fields = [ + x.column_name for x in self.model._partition_keys.values() + ] return clone @@ -945,7 +1038,7 @@ def allow_filtering(self): return clone def _only_or_defer(self, action, fields): - if action == 'only' and self._only_fields: + if action == "only" and self._only_fields: raise QueryException("QuerySet already has 'only' fields defined") clone = copy.deepcopy(self) @@ -955,13 +1048,15 @@ def _only_or_defer(self, action, fields): if missing_fields: raise QueryException( "Can't resolve fields {0} in {1}".format( - ', '.join(missing_fields), self.model.__name__)) + ", ".join(missing_fields), self.model.__name__ + ) + ) fields = [self.model._columns[field].db_field_name for field in fields] - if action == 'defer': + if action == "defer": clone._defer_fields.update(fields) - elif action == 'only': + elif action == "only": clone._only_fields = fields else: raise ValueError @@ -969,30 +1064,34 @@ def _only_or_defer(self, action, fields): return clone def only(self, fields): - """ Load only these fields for the returned query """ - return self._only_or_defer('only', fields) + """Load only these fields for the returned query""" + return self._only_or_defer("only", fields) def defer(self, fields): - """ Don't load these fields for the returned query """ - return self._only_or_defer('defer', fields) + """Don't load these fields for the returned query""" + return self._only_or_defer("defer", fields) def create(self, **kwargs): - return self.model(**kwargs) \ - .batch(self._batch) \ - .ttl(self._ttl) \ - .consistency(self._consistency) \ - .if_not_exists(self._if_not_exists) \ - .timestamp(self._timestamp) \ - .if_exists(self._if_exists) \ - .using(connection=self._connection) \ + return ( + self.model(**kwargs) + .batch(self._batch) + .ttl(self._ttl) + .consistency(self._consistency) + .if_not_exists(self._if_not_exists) + .timestamp(self._timestamp) + .if_exists(self._if_exists) + .using(connection=self._connection) .save() + ) def delete(self): """ Deletes the contents of a query """ # validate where clause - partition_keys = set(x.db_field_name for x in self.model._partition_keys.values()) + partition_keys = set( + x.db_field_name for x in self.model._partition_keys.values() + ) if partition_keys - set(c.field for c in self._where): raise QueryException("The partition key must be defined on delete queries") @@ -1001,7 +1100,7 @@ def delete(self): where=self._where, timestamp=self._timestamp, conditionals=self._conditional, - if_exists=self._if_exists + if_exists=self._if_exists, ) self._execute(dq) @@ -1028,12 +1127,15 @@ def using(self, keyspace=None, connection=None): """ if connection and self._batch: - raise CQLEngineException("Cannot specify a connection on model in batch mode.") + raise CQLEngineException( + "Cannot specify a connection on model in batch mode." + ) clone = copy.deepcopy(self) if keyspace: from cassandra.cqlengine.models import _clone_model_class - clone.model = _clone_model_class(self.model, {'__keyspace__': keyspace}) + + clone.model = _clone_model_class(self.model, {"__keyspace__": keyspace}) if connection: clone._connection = connection @@ -1066,37 +1168,55 @@ def _get_result_constructor(self): class ModelQuerySet(AbstractQuerySet): - """ - """ + """ """ + def _validate_select_where(self): - """ Checks that a filterset will not create invalid select statement """ + """Checks that a filterset will not create invalid select statement""" # check that there's either a =, a IN or a CONTAINS (collection) # relationship with a primary key or indexed field. We also allow # custom indexes to be queried with any operator (a difference # between a secondary index) - equal_ops = [self.model._get_column_by_db_name(w.field) \ - for w in self._where if not isinstance(w.value, Token) - and (isinstance(w.operator, EqualsOperator) - or self.model._get_column_by_db_name(w.field).custom_index)] + equal_ops = [ + self.model._get_column_by_db_name(w.field) + for w in self._where + if not isinstance(w.value, Token) + and ( + isinstance(w.operator, EqualsOperator) + or self.model._get_column_by_db_name(w.field).custom_index + ) + ] token_comparison = any([w for w in self._where if isinstance(w.value, Token)]) - if not any(w.primary_key or w.has_index for w in equal_ops) and not token_comparison and not self._allow_filtering: + if ( + not any(w.primary_key or w.has_index for w in equal_ops) + and not token_comparison + and not self._allow_filtering + ): raise QueryException( - ('Where clauses require either =, a IN or a CONTAINS ' - '(collection) comparison with either a primary key or ' - 'indexed field. You might want to consider setting ' - 'custom_index on fields that you manage index outside ' - 'cqlengine.')) + ( + "Where clauses require either =, a IN or a CONTAINS " + "(collection) comparison with either a primary key or " + "indexed field. You might want to consider setting " + "custom_index on fields that you manage index outside " + "cqlengine." + ) + ) if not self._allow_filtering: # if the query is not on an indexed field if not any(w.has_index for w in equal_ops): - if not any([w.partition_key for w in equal_ops]) and not token_comparison: + if ( + not any([w.partition_key for w in equal_ops]) + and not token_comparison + ): raise QueryException( - ('Filtering on a clustering key without a partition ' - 'key is not allowed unless allow_filtering() is ' - 'called on the queryset. You might want to consider ' - 'setting custom_index on fields that you manage ' - 'index outside cqlengine.')) + ( + "Filtering on a clustering key without a partition " + "key is not allowed unless allow_filtering() is " + "called on the queryset. You might want to consider " + "setting custom_index on fields that you manage " + "index outside cqlengine." + ) + ) def _select_fields(self): if self._defer_fields or self._only_fields: @@ -1105,27 +1225,37 @@ def _select_fields(self): fields = [f for f in fields if f not in self._defer_fields] # select the partition keys if all model fields are set defer if not fields: - fields = [columns.db_field_name for columns in self.model._partition_keys.values()] + fields = [ + columns.db_field_name + for columns in self.model._partition_keys.values() + ] if self._only_fields: fields = [f for f in fields if f in self._only_fields] if not fields: - raise QueryException('No fields in select query. Only fields: "{0}", defer fields: "{1}"'.format( - ','.join(self._only_fields), ','.join(self._defer_fields))) + raise QueryException( + 'No fields in select query. Only fields: "{0}", defer fields: "{1}"'.format( + ",".join(self._only_fields), ",".join(self._defer_fields) + ) + ) return fields return super(ModelQuerySet, self)._select_fields() def _get_result_constructor(self): - """ Returns a function that will be used to instantiate query results """ + """Returns a function that will be used to instantiate query results""" if not self._values_list: # we want models return self.model._construct_instance - elif self._flat_values_list: # the user has requested flattened list (1 value per row) + elif ( + self._flat_values_list + ): # the user has requested flattened list (1 value per row) key = self._only_fields[0] return lambda row: row[key] else: return lambda row: [row[f] for f in self._only_fields] def _get_ordering_condition(self, colname): - colname, order_type = super(ModelQuerySet, self)._get_ordering_condition(colname) + colname, order_type = super(ModelQuerySet, self)._get_ordering_condition( + colname + ) column = self.model._columns.get(colname) if column is None: @@ -1134,23 +1264,30 @@ def _get_ordering_condition(self, colname): # validate the column selection if not column.primary_key: raise QueryException( - "Can't order on '{0}', can only order on (clustered) primary keys".format(colname)) + "Can't order on '{0}', can only order on (clustered) primary keys".format( + colname + ) + ) pks = [v for k, v in self.model._columns.items() if v.primary_key] if column == pks[0]: raise QueryException( - "Can't order by the first primary key (partition key), clustering (secondary) keys only") + "Can't order by the first primary key (partition key), clustering (secondary) keys only" + ) return column.db_field_name, order_type def values_list(self, *fields, **kwargs): - """ Instructs the query set to return tuples, not model instance """ - flat = kwargs.pop('flat', False) + """Instructs the query set to return tuples, not model instance""" + flat = kwargs.pop("flat", False) if kwargs: - raise TypeError('Unexpected keyword arguments to values_list: %s' - % (kwargs.keys(),)) + raise TypeError( + "Unexpected keyword arguments to values_list: %s" % (kwargs.keys(),) + ) if flat and len(fields) > 1: - raise TypeError("'flat' is not valid when values_list is called with more than one field.") + raise TypeError( + "'flat' is not valid when values_list is called with more than one field." + ) clone = self.only(fields) clone._values_list = True clone._flat_values_list = flat @@ -1181,7 +1318,9 @@ def if_not_exists(self): If the insertion isn't applied, a LWTException is raised. """ if self.model._has_counter: - raise IfNotExistsWithCounterColumn('if_not_exists cannot be used with tables containing counter columns') + raise IfNotExistsWithCounterColumn( + "if_not_exists cannot be used with tables containing counter columns" + ) clone = copy.deepcopy(self) clone._if_not_exists = True return clone @@ -1193,7 +1332,9 @@ def if_exists(self): If the update or delete isn't applied, a LWTException is raised. """ if self.model._has_counter: - raise IfExistsWithCounterColumn('if_exists cannot be used with tables containing counter columns') + raise IfExistsWithCounterColumn( + "if_exists cannot be used with tables containing counter columns" + ) clone = copy.deepcopy(self) clone._if_exists = True return clone @@ -1293,22 +1434,39 @@ class Row(Model): nulled_columns = set() updated_columns = set() - us = UpdateStatement(self.column_family_name, where=self._where, ttl=self._ttl, - timestamp=self._timestamp, conditionals=self._conditional, if_exists=self._if_exists) + us = UpdateStatement( + self.column_family_name, + where=self._where, + ttl=self._ttl, + timestamp=self._timestamp, + conditionals=self._conditional, + if_exists=self._if_exists, + ) for name, val in values.items(): col_name, col_op = self._parse_filter_arg(name) col = self.model._columns.get(col_name) # check for nonexistant columns if col is None: - raise ValidationError("{0}.{1} has no column named: {2}".format(self.__module__, self.model.__name__, col_name)) + raise ValidationError( + "{0}.{1} has no column named: {2}".format( + self.__module__, self.model.__name__, col_name + ) + ) # check for primary key update attempts if col.is_primary_key: - raise ValidationError("Cannot apply update to primary key '{0}' for {1}.{2}".format(col_name, self.__module__, self.model.__name__)) + raise ValidationError( + "Cannot apply update to primary key '{0}' for {1}.{2}".format( + col_name, self.__module__, self.model.__name__ + ) + ) - if col_op == 'remove' and isinstance(col, columns.Map): + if col_op == "remove" and isinstance(col, columns.Map): if not isinstance(val, set): raise ValidationError( - "Cannot apply update operation '{0}' on column '{1}' with value '{2}'. A set is required.".format(col_op, col_name, val)) + "Cannot apply update operation '{0}' on column '{1}' with value '{2}'. A set is required.".format( + col_op, col_name, val + ) + ) val = {v: None for v in val} else: # we should not provide default values in this use case. @@ -1325,10 +1483,22 @@ class Row(Model): self._execute(us) if nulled_columns: - delete_conditional = [condition for condition in self._conditional - if condition.field not in updated_columns] if self._conditional else None - ds = DeleteStatement(self.column_family_name, fields=nulled_columns, - where=self._where, conditionals=delete_conditional, if_exists=self._if_exists) + delete_conditional = ( + [ + condition + for condition in self._conditional + if condition.field not in updated_columns + ] + if self._conditional + else None + ) + ds = DeleteStatement( + self.column_family_name, + fields=nulled_columns, + where=self._where, + conditionals=delete_conditional, + if_exists=self._if_exists, + ) self._execute(ds) @@ -1340,14 +1510,26 @@ class DMLQuery(object): unlike the read query object, this is mutable """ + _ttl = None _consistency = None _timestamp = None _if_not_exists = False _if_exists = False - def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None, - if_not_exists=False, conditional=None, timeout=conn.NOT_SET, if_exists=False): + def __init__( + self, + model, + instance=None, + batch=None, + ttl=None, + consistency=None, + timestamp=None, + if_not_exists=False, + conditional=None, + timeout=conn.NOT_SET, + if_exists=False, + ): self.model = model self.column_family_name = self.model.column_family_name() self.instance = instance @@ -1361,25 +1543,40 @@ def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, self._timeout = timeout def _execute(self, statement): - connection = self.instance._get_connection() if self.instance else self.model._get_connection() + connection = ( + self.instance._get_connection() + if self.instance + else self.model._get_connection() + ) if self._batch: if self._batch._connection: - if not self._batch._connection_explicit and connection and \ - connection != self._batch._connection: - raise CQLEngineException('BatchQuery queries must be executed on the same connection') + if ( + not self._batch._connection_explicit + and connection + and connection != self._batch._connection + ): + raise CQLEngineException( + "BatchQuery queries must be executed on the same connection" + ) else: # set the BatchQuery connection from the model self._batch._connection = connection return self._batch.add_query(statement) else: - results = _execute_statement(self.model, statement, self._consistency, self._timeout, connection=connection) + results = _execute_statement( + self.model, + statement, + self._consistency, + self._timeout, + connection=connection, + ) if self._if_not_exists or self._if_exists or self._conditional: check_applied(results) return results def batch(self, batch_obj): if batch_obj is not None and not isinstance(batch_obj, BatchQuery): - raise CQLEngineException('batch_obj must be a BatchQuery instance or None') + raise CQLEngineException("batch_obj must be a BatchQuery instance or None") self._batch = batch_obj return self @@ -1387,7 +1584,11 @@ def _delete_null_columns(self, conditionals=None): """ executes a delete query to remove columns that have changed to null """ - ds = DeleteStatement(self.column_family_name, conditionals=conditionals, if_exists=self._if_exists) + ds = DeleteStatement( + self.column_family_name, + conditionals=conditionals, + if_exists=self._if_exists, + ) deleted_fields = False static_only = True for _, v in self.instance._values.items(): @@ -1404,7 +1605,9 @@ def _delete_null_columns(self, conditionals=None): static_only |= col.static if deleted_fields: - keys = self.model._partition_keys if static_only else self.model._primary_keys + keys = ( + self.model._partition_keys if static_only else self.model._primary_keys + ) for name, col in keys.items(): ds.add_where(col, EqualsOperator(), getattr(self.instance, name)) self._execute(ds) @@ -1419,12 +1622,21 @@ def update(self): if self.instance is None: raise CQLEngineException("DML Query intance attribute is None") assert type(self.instance) == self.model - null_clustering_key = False if len(self.instance._clustering_keys) == 0 else True + null_clustering_key = ( + False if len(self.instance._clustering_keys) == 0 else True + ) static_changed_only = True - statement = UpdateStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp, - conditionals=self._conditional, if_exists=self._if_exists) + statement = UpdateStatement( + self.column_family_name, + ttl=self._ttl, + timestamp=self._timestamp, + conditionals=self._conditional, + if_exists=self._if_exists, + ) for name, col in self.instance._clustering_keys.items(): - null_clustering_key = null_clustering_key and col._val_is_null(getattr(self.instance, name, None)) + null_clustering_key = null_clustering_key and col._val_is_null( + getattr(self.instance, name, None) + ) updated_columns = set() # get defined fields and their column names @@ -1449,15 +1661,24 @@ def update(self): if statement.assignments: for name, col in self.model._primary_keys.items(): # only include clustering key if clustering key is not null, and non static columns are changed to avoid cql error - if (null_clustering_key or static_changed_only) and (not col.partition_key): + if (null_clustering_key or static_changed_only) and ( + not col.partition_key + ): continue statement.add_where(col, EqualsOperator(), getattr(self.instance, name)) self._execute(statement) if not null_clustering_key: # remove conditions on fields that have been updated - delete_conditionals = [condition for condition in self._conditional - if condition.field not in updated_columns] if self._conditional else None + delete_conditionals = ( + [ + condition + for condition in self._conditional + if condition.field not in updated_columns + ] + if self._conditional + else None + ) self._delete_null_columns(delete_conditionals) def save(self): @@ -1474,14 +1695,26 @@ def save(self): nulled_fields = set() if self.instance._has_counter or self.instance._can_update(): if self.instance._has_counter: - warn("'create' and 'save' actions on Counters are deprecated. It will be disallowed in 4.0. " - "Use the 'update' mechanism instead.", DeprecationWarning) + warn( + "'create' and 'save' actions on Counters are deprecated. It will be disallowed in 4.0. " + "Use the 'update' mechanism instead.", + DeprecationWarning, + ) return self.update() else: - insert = InsertStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp, if_not_exists=self._if_not_exists) - static_save_only = False if len(self.instance._clustering_keys) == 0 else True + insert = InsertStatement( + self.column_family_name, + ttl=self._ttl, + timestamp=self._timestamp, + if_not_exists=self._if_not_exists, + ) + static_save_only = ( + False if len(self.instance._clustering_keys) == 0 else True + ) for name, col in self.instance._clustering_keys.items(): - static_save_only = static_save_only and col._val_is_null(getattr(self.instance, name, None)) + static_save_only = static_save_only and col._val_is_null( + getattr(self.instance, name, None) + ) for name, col in self.instance._columns.items(): if static_save_only and not col.static and not col.partition_key: continue @@ -1504,11 +1737,16 @@ def save(self): self._delete_null_columns() def delete(self): - """ Deletes one instance """ + """Deletes one instance""" if self.instance is None: raise CQLEngineException("DML Query instance attribute is None") - ds = DeleteStatement(self.column_family_name, timestamp=self._timestamp, conditionals=self._conditional, if_exists=self._if_exists) + ds = DeleteStatement( + self.column_family_name, + timestamp=self._timestamp, + conditionals=self._conditional, + if_exists=self._if_exists, + ) for name, col in self.model._primary_keys.items(): val = getattr(self.instance, name) if val is None and not col.partition_key: @@ -1519,11 +1757,17 @@ def delete(self): def _execute_statement(model, statement, consistency_level, timeout, connection=None): params = statement.get_context() - s = SimpleStatement(str(statement), consistency_level=consistency_level, fetch_size=statement.fetch_size) + s = SimpleStatement( + str(statement), + consistency_level=consistency_level, + fetch_size=statement.fetch_size, + ) if model._partition_key_index: key_values = statement.partition_key_values(model._partition_key_index) if not any(v is None for v in key_values): - parts = model._routing_key_from_values(key_values, conn.get_cluster(connection).protocol_version) + parts = model._routing_key_from_values( + key_values, conn.get_cluster(connection).protocol_version + ) s.routing_key = parts s.keyspace = model._get_keyspace() connection = connection or model._get_connection() From 8b1e14637e63245ff4badf596476d4d164b88bd4 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Wed, 4 Mar 2026 19:04:30 +0200 Subject: [PATCH 07/18] refactor: rename __unicode__ to __str__ in models, functions, operators, named Rename the remaining 5 __unicode__ method definitions across cqlengine/models.py (ColumnQueryEvaluator), cqlengine/functions.py (QueryValue, Token), cqlengine/operators.py (BaseQueryOperator), and cqlengine/named.py (NamedColumn) to __str__. This is part of the systematic removal of the Python 2 UnicodeMixin pattern. The __str__ definitions on each class now take precedence over the inherited UnicodeMixin.__str__ lambda, so behavior is unchanged. --- cassandra/cqlengine/functions.py | 24 +- cassandra/cqlengine/models.py | 399 +++++++++++++++++++++---------- cassandra/cqlengine/named.py | 36 ++- cassandra/cqlengine/operators.py | 38 +-- 4 files changed, 332 insertions(+), 165 deletions(-) diff --git a/cassandra/cqlengine/functions.py b/cassandra/cqlengine/functions.py index 606f5bc330..c408d8096a 100644 --- a/cassandra/cqlengine/functions.py +++ b/cassandra/cqlengine/functions.py @@ -16,22 +16,24 @@ from cassandra.cqlengine import UnicodeMixin, ValidationError + def get_total_seconds(td): return td.total_seconds() + class QueryValue(UnicodeMixin): """ Base class for query filter values. Subclasses of these classes can be passed into .filter() keyword args """ - format_string = '%({0})s' + format_string = "%({0})s" def __init__(self, value): self.value = value self.context_id = None - def __unicode__(self): + def __str__(self): return self.format_string.format(self.context_id) def set_context_id(self, ctx_id): @@ -50,18 +52,18 @@ class BaseQueryFunction(QueryValue): be passed into .filter() and will be translated into CQL functions in the resulting query """ + pass class TimeUUIDQueryFunction(BaseQueryFunction): - def __init__(self, value): """ :param value: the time to create bounding time uuid from :type value: datetime """ if not isinstance(value, datetime): - raise ValidationError('datetime instance is required') + raise ValidationError("datetime instance is required") super(TimeUUIDQueryFunction, self).__init__(value) def to_database(self, val): @@ -79,7 +81,8 @@ class MinTimeUUID(TimeUUIDQueryFunction): http://cassandra.apache.org/doc/cql3/CQL-3.0.html#timeuuidFun """ - format_string = 'MinTimeUUID(%({0})s)' + + format_string = "MinTimeUUID(%({0})s)" class MaxTimeUUID(TimeUUIDQueryFunction): @@ -88,7 +91,8 @@ class MaxTimeUUID(TimeUUIDQueryFunction): http://cassandra.apache.org/doc/cql3/CQL-3.0.html#timeuuidFun """ - format_string = 'MaxTimeUUID(%({0})s)' + + format_string = "MaxTimeUUID(%({0})s)" class Token(BaseQueryFunction): @@ -97,6 +101,7 @@ class Token(BaseQueryFunction): http://cassandra.apache.org/doc/cql3/CQL-3.0.html#tokenFun """ + def __init__(self, *values): if len(values) == 1 and isinstance(values[0], (list, tuple)): values = values[0] @@ -109,8 +114,11 @@ def set_columns(self, columns): def get_context_size(self): return len(self.value) - def __unicode__(self): - token_args = ', '.join('%({0})s'.format(self.context_id + i) for i in range(self.get_context_size())) + def __str__(self): + token_args = ", ".join( + "%({0})s".format(self.context_id + i) + for i in range(self.get_context_size()) + ) return "token({0})".format(token_args) def update_context(self, ctx): diff --git a/cassandra/cqlengine/models.py b/cassandra/cqlengine/models.py index bc00001666..c12ccdf29c 100644 --- a/cassandra/cqlengine/models.py +++ b/cassandra/cqlengine/models.py @@ -21,7 +21,9 @@ from cassandra.cqlengine import connection from cassandra.cqlengine import query from cassandra.cqlengine.query import DoesNotExist as _DoesNotExist -from cassandra.cqlengine.query import MultipleObjectsReturned as _MultipleObjectsReturned +from cassandra.cqlengine.query import ( + MultipleObjectsReturned as _MultipleObjectsReturned, +) from cassandra.metadata import protect_name from cassandra.util import OrderedDict @@ -54,6 +56,7 @@ class PolymorphicModelException(ModelException): class UndefinedKeyspaceWarning(Warning): pass + DEFAULT_KEYSPACE = None @@ -62,6 +65,7 @@ class hybrid_classmethod(object): Allows a method to behave as both a class method and normal instance method depending on how it's called """ + def __init__(self, clsmethod, instmethod): self.clsmethod = clsmethod self.instmethod = instmethod @@ -86,9 +90,9 @@ class QuerySetDescriptor(object): """ def __get__(self, obj, model): - """ :rtype: ModelQuerySet """ + """:rtype: ModelQuerySet""" if model.__abstract__: - raise CQLEngineException('cannot execute queries against abstract models') + raise CQLEngineException("cannot execute queries against abstract models") queryset = model.__queryset__(model) # if this is a concrete polymorphic model, and the discriminator @@ -115,13 +119,17 @@ class ConditionalDescriptor(object): """ returns a query set descriptor """ + def __get__(self, instance, model): if instance: + def conditional_setter(*prepared_conditional, **unprepared_conditionals): if len(prepared_conditional) > 0: conditionals = prepared_conditional[0] else: - conditionals = instance.objects.iff(**unprepared_conditionals)._conditional + conditionals = instance.objects.iff( + **unprepared_conditionals + )._conditional instance._conditional = conditionals return instance @@ -132,6 +140,7 @@ def conditional_setter(**unprepared_conditionals): conditionals = model.objects.iff(**unprepared_conditionals)._conditional qs._conditional = conditionals return qs + return conditional_setter def __call__(self, *args, **kwargs): @@ -142,6 +151,7 @@ class TTLDescriptor(object): """ returns a query set descriptor """ + def __get__(self, instance, model): if instance: # instance = copy.deepcopy(instance) @@ -149,6 +159,7 @@ def __get__(self, instance, model): def ttl_setter(ts): instance._ttl = ts return instance + return ttl_setter qs = model.__queryset__(model) @@ -167,12 +178,14 @@ class TimestampDescriptor(object): """ returns a query set descriptor with a timestamp specified """ + def __get__(self, instance, model): if instance: # instance method def timestamp_setter(ts): instance._timestamp = ts return instance + return timestamp_setter return model.objects.timestamp @@ -185,12 +198,14 @@ class IfNotExistsDescriptor(object): """ return a query set descriptor with a if_not_exists flag specified """ + def __get__(self, instance, model): if instance: # instance method def ifnotexists_setter(ife=True): instance._if_not_exists = ife return instance + return ifnotexists_setter return model.objects.if_not_exists @@ -203,12 +218,14 @@ class IfExistsDescriptor(object): """ return a query set descriptor with a if_exists flag specified """ + def __get__(self, instance, model): if instance: # instance method def ifexists_setter(ife=True): instance._if_exists = ife return instance + return ifexists_setter return model.objects.if_exists @@ -221,12 +238,14 @@ class ConsistencyDescriptor(object): """ returns a query set descriptor if called on Class, instance if it was an instance call """ + def __get__(self, instance, model): if instance: # instance = copy.deepcopy(instance) def consistency_setter(consistency): instance.__consistency__ = consistency return instance + return consistency_setter qs = model.__queryset__(model) @@ -245,6 +264,7 @@ class UsingDescriptor(object): """ return a query set descriptor with a connection context specified """ + def __get__(self, instance, model): if instance: # instance method @@ -252,6 +272,7 @@ def using_setter(connection=None): if connection: instance._connection = connection return instance + return using_setter return model.objects.using @@ -272,7 +293,7 @@ class ColumnQueryEvaluator(query.AbstractQueryableColumn): def __init__(self, column): self.column = column - def __unicode__(self): + def __str__(self): return self.column.db_field_name def _get_column(self): @@ -316,7 +337,7 @@ def __set__(self, instance, value): if instance: return instance._values[self.column.column_name].setval(value) else: - raise AttributeError('cannot reassign column values') + raise AttributeError("cannot reassign column values") def __delete__(self, instance): """ @@ -326,7 +347,9 @@ def __delete__(self, instance): if self.column.can_delete: instance._values[self.column.column_name].delval() else: - raise AttributeError('cannot delete {0} columns'.format(self.column.column_name)) + raise AttributeError( + "cannot delete {0} columns".format(self.column.column_name) + ) class BaseModel(object): @@ -376,9 +399,13 @@ class MultipleObjectsReturned(_MultipleObjectsReturned): __consistency__ = None # can be set per query - _timestamp = None # optional timestamp to include with the operation (USING TIMESTAMP) + _timestamp = ( + None # optional timestamp to include with the operation (USING TIMESTAMP) + ) - _if_not_exists = False # optional if_not_exists flag to check existence before insertion + _if_not_exists = ( + False # optional if_not_exists flag to check existence before insertion + ) _if_exists = False # optional if_exists flag to check existence before update @@ -409,17 +436,25 @@ def __init__(self, **values): self._values[name] = value_mngr def __repr__(self): - return '{0}({1})'.format(self.__class__.__name__, - ', '.join('{0}={1!r}'.format(k, getattr(self, k)) - for k in self._defined_columns.keys() - if k != self._discriminator_column_name)) + return "{0}({1})".format( + self.__class__.__name__, + ", ".join( + "{0}={1!r}".format(k, getattr(self, k)) + for k in self._defined_columns.keys() + if k != self._discriminator_column_name + ), + ) def __str__(self): """ Pretty printing of models by their primary key """ - return '{0} <{1}>'.format(self.__class__.__name__, - ', '.join('{0}={1}'.format(k, getattr(self, k)) for k in self._primary_keys.keys())) + return "{0} <{1}>".format( + self.__class__.__name__, + ", ".join( + "{0}={1}".format(k, getattr(self, k)) for k in self._primary_keys.keys() + ), + ) @classmethod def _routing_key_from_values(cls, pk_values, protocol_version): @@ -428,19 +463,27 @@ def _routing_key_from_values(cls, pk_values, protocol_version): @classmethod def _discover_polymorphic_submodels(cls): if not cls._is_polymorphic_base: - raise ModelException('_discover_polymorphic_submodels can only be called on polymorphic base classes') + raise ModelException( + "_discover_polymorphic_submodels can only be called on polymorphic base classes" + ) def _discover(klass): - if not klass._is_polymorphic_base and klass.__discriminator_value__ is not None: + if ( + not klass._is_polymorphic_base + and klass.__discriminator_value__ is not None + ): cls._discriminator_map[klass.__discriminator_value__] = klass for subklass in klass.__subclasses__(): _discover(subklass) + _discover(cls) @classmethod def _get_model_by_discriminator_value(cls, key): if not cls._is_polymorphic_base: - raise ModelException('_get_model_by_discriminator_value can only be called on polymorphic base classes') + raise ModelException( + "_get_model_by_discriminator_value can only be called on polymorphic base classes" + ) return cls._discriminator_map.get(key) @classmethod @@ -459,7 +502,9 @@ def _construct_instance(cls, values): disc_key = values.get(cls._discriminator_column_name) if disc_key is None: - raise PolymorphicModelException('discriminator value was not found in values') + raise PolymorphicModelException( + "discriminator value was not found in values" + ) poly_base = cls if cls._is_polymorphic_base else cls._polymorphic_base @@ -469,15 +514,19 @@ def _construct_instance(cls, values): klass = poly_base._get_model_by_discriminator_value(disc_key) if klass is None: raise PolymorphicModelException( - 'unrecognized discriminator column {0} for class {1}'.format(disc_key, poly_base.__name__) + "unrecognized discriminator column {0} for class {1}".format( + disc_key, poly_base.__name__ + ) ) if not issubclass(klass, cls): raise PolymorphicModelException( - '{0} is not a subclass of {1}'.format(klass.__name__, cls.__name__) + "{0} is not a subclass of {1}".format(klass.__name__, cls.__name__) ) - values = dict((k, v) for k, v in values.items() if k in klass._columns.keys()) + values = dict( + (k, v) for k, v in values.items() if k in klass._columns.keys() + ) else: klass = cls @@ -540,7 +589,9 @@ def __eq__(self, other): if keys != other_keys: return False - return all(getattr(self, key, None) == getattr(other, key, None) for key in other_keys) + return all( + getattr(self, key, None) == getattr(other, key, None) for key in other_keys + ) def __ne__(self, other): return not self.__eq__(other) @@ -555,37 +606,47 @@ def column_family_name(cls, include_keyspace=True): if include_keyspace: keyspace = cls._get_keyspace() if not keyspace: - raise CQLEngineException("Model keyspace is not set and no default is available. Set model keyspace or setup connection before attempting to generate a query.") - return '{0}.{1}'.format(protect_name(keyspace), cf_name) + raise CQLEngineException( + "Model keyspace is not set and no default is available. Set model keyspace or setup connection before attempting to generate a query." + ) + return "{0}.{1}".format(protect_name(keyspace), cf_name) return cf_name - @classmethod def _raw_column_family_name(cls): if not cls._table_name: if cls.__table_name__: if cls.__table_name_case_sensitive__: - warn("Model __table_name_case_sensitive__ will be removed in 4.0.", PendingDeprecationWarning) + warn( + "Model __table_name_case_sensitive__ will be removed in 4.0.", + PendingDeprecationWarning, + ) cls._table_name = cls.__table_name__ else: table_name = cls.__table_name__.lower() if cls.__table_name__ != table_name: - warn(("Model __table_name__ will be case sensitive by default in 4.0. " - "You should fix the __table_name__ value of the '{0}' model.").format(cls.__name__)) + warn( + ( + "Model __table_name__ will be case sensitive by default in 4.0. " + "You should fix the __table_name__ value of the '{0}' model." + ).format(cls.__name__) + ) cls._table_name = table_name else: if cls._is_polymorphic and not cls._is_polymorphic_base: cls._table_name = cls._polymorphic_base._raw_column_family_name() else: - camelcase = re.compile(r'([a-z])([A-Z])') - ccase = lambda s: camelcase.sub(lambda v: '{0}_{1}'.format(v.group(1), v.group(2).lower()), s) + camelcase = re.compile(r"([a-z])([A-Z])") + ccase = lambda s: camelcase.sub( + lambda v: "{0}_{1}".format(v.group(1), v.group(2).lower()), s + ) cf_name = ccase(cls.__name__) # trim to less than 48 characters or cassandra will complain cf_name = cf_name[-48:] cf_name = cf_name.lower() - cf_name = re.sub(r'^_+', '', cf_name) + cf_name = re.sub(r"^_+", "", cf_name) cls._table_name = cf_name return cls._table_name @@ -607,12 +668,12 @@ def validate(self): # Let an instance be used like a dict of its columns keys/values def __iter__(self): - """ Iterate over column ids. """ + """Iterate over column ids.""" for column_id in self._columns.keys(): yield column_id def __getitem__(self, key): - """ Returns column's value. """ + """Returns column's value.""" if not isinstance(key, str): raise TypeError if key not in self._columns.keys(): @@ -620,7 +681,7 @@ def __getitem__(self, key): return getattr(self, key) def __setitem__(self, key, val): - """ Sets a column's value. """ + """Sets a column's value.""" if not isinstance(key, str): raise TypeError if key not in self._columns.keys(): @@ -638,19 +699,19 @@ def __len__(self): return self._len def keys(self): - """ Returns a list of column IDs. """ + """Returns a list of column IDs.""" return [k for k in self] def values(self): - """ Returns list of column values. """ + """Returns list of column values.""" return [self[k] for k in self] def items(self): - """ Returns a list of column ID/value tuples. """ + """Returns a list of column ID/value tuples.""" return [(k, self[k]) for k in self] def _as_dict(self): - """ Returns a map of column names to cleaned values """ + """Returns a map of column names to cleaned values""" values = self._dynamic_columns or {} for name, col in self._columns.items(): values[name] = col.to_database(getattr(self, name, None)) @@ -703,7 +764,7 @@ def timeout(self, timeout): Sets a timeout for use in :meth:`~.save`, :meth:`~.update`, and :meth:`~.delete` operations """ - assert self._batch is None, 'Setting both timeout and batch is not supported' + assert self._batch is None, "Setting both timeout and batch is not supported" self._timeout = timeout return self @@ -722,20 +783,25 @@ def save(self): # handle polymorphic models if self._is_polymorphic: if self._is_polymorphic_base: - raise PolymorphicModelException('cannot save polymorphic base model') + raise PolymorphicModelException("cannot save polymorphic base model") else: - setattr(self, self._discriminator_column_name, self.__discriminator_value__) + setattr( + self, self._discriminator_column_name, self.__discriminator_value__ + ) self.validate() - self.__dmlquery__(self.__class__, self, - batch=self._batch, - ttl=self._ttl, - timestamp=self._timestamp, - consistency=self.__consistency__, - if_not_exists=self._if_not_exists, - conditional=self._conditional, - timeout=self._timeout, - if_exists=self._if_exists).save() + self.__dmlquery__( + self.__class__, + self, + batch=self._batch, + ttl=self._ttl, + timestamp=self._timestamp, + consistency=self.__consistency__, + if_not_exists=self._if_not_exists, + conditional=self._conditional, + timeout=self._timeout, + if_exists=self._if_exists, + ).save() self._set_persisted() @@ -761,7 +827,9 @@ def update(self, **values): if col is None: raise ValidationError( "{0}.{1} has no column named: {2}".format( - self.__module__, self.__class__.__name__, column_id)) + self.__module__, self.__class__.__name__, column_id + ) + ) # check for primary key update attempts if col.is_primary_key: @@ -769,26 +837,33 @@ def update(self, **values): if v != current_value: raise ValidationError( "Cannot apply update to primary key '{0}' for {1}.{2}".format( - column_id, self.__module__, self.__class__.__name__)) + column_id, self.__module__, self.__class__.__name__ + ) + ) setattr(self, column_id, v) # handle polymorphic models if self._is_polymorphic: if self._is_polymorphic_base: - raise PolymorphicModelException('cannot update polymorphic base model') + raise PolymorphicModelException("cannot update polymorphic base model") else: - setattr(self, self._discriminator_column_name, self.__discriminator_value__) + setattr( + self, self._discriminator_column_name, self.__discriminator_value__ + ) self.validate() - self.__dmlquery__(self.__class__, self, - batch=self._batch, - ttl=self._ttl, - timestamp=self._timestamp, - consistency=self.__consistency__, - conditional=self._conditional, - timeout=self._timeout, - if_exists=self._if_exists).update() + self.__dmlquery__( + self.__class__, + self, + batch=self._batch, + ttl=self._ttl, + timestamp=self._timestamp, + consistency=self.__consistency__, + conditional=self._conditional, + timeout=self._timeout, + if_exists=self._if_exists, + ).update() self._set_persisted() @@ -800,13 +875,16 @@ def delete(self): """ Deletes the object from the database """ - self.__dmlquery__(self.__class__, self, - batch=self._batch, - timestamp=self._timestamp, - consistency=self.__consistency__, - timeout=self._timeout, - conditional=self._conditional, - if_exists=self._if_exists).delete() + self.__dmlquery__( + self.__class__, + self, + batch=self._batch, + timestamp=self._timestamp, + consistency=self.__consistency__, + timeout=self._timeout, + conditional=self._conditional, + if_exists=self._if_exists, + ).delete() def get_changed_columns(self): """ @@ -819,9 +897,13 @@ def _class_batch(cls, batch): return cls.objects.batch(batch) def _inst_batch(self, batch): - assert self._timeout is connection.NOT_SET, 'Setting both timeout and batch is not supported' + assert self._timeout is connection.NOT_SET, ( + "Setting both timeout and batch is not supported" + ) if self._connection: - raise CQLEngineException("Cannot specify a connection on model in batch mode.") + raise CQLEngineException( + "Cannot specify a connection on model in batch mode." + ) self._batch = batch return self @@ -838,7 +920,6 @@ def _inst_get_connection(self): class ModelMetaClass(type): - def __new__(cls, name, bases, attrs): # move column definitions into columns dict # and set default column names @@ -849,48 +930,68 @@ def __new__(cls, name, bases, attrs): # get inherited properties inherited_columns = OrderedDict() for base in bases: - for k, v in getattr(base, '_defined_columns', {}).items(): + for k, v in getattr(base, "_defined_columns", {}).items(): inherited_columns.setdefault(k, v) # short circuit __abstract__ inheritance - is_abstract = attrs['__abstract__'] = attrs.get('__abstract__', False) + is_abstract = attrs["__abstract__"] = attrs.get("__abstract__", False) # short circuit __discriminator_value__ inheritance - attrs['__discriminator_value__'] = attrs.get('__discriminator_value__') + attrs["__discriminator_value__"] = attrs.get("__discriminator_value__") # TODO __default__ttl__ should be removed in the next major release - options = attrs.get('__options__') or {} - attrs['__default_ttl__'] = options.get('default_time_to_live') + options = attrs.get("__options__") or {} + attrs["__default_ttl__"] = options.get("default_time_to_live") - column_definitions = [(k, v) for k, v in attrs.items() if isinstance(v, columns.Column)] + column_definitions = [ + (k, v) for k, v in attrs.items() if isinstance(v, columns.Column) + ] column_definitions = sorted(column_definitions, key=lambda x: x[1].position) - is_polymorphic_base = any([c[1].discriminator_column for c in column_definitions]) + is_polymorphic_base = any( + [c[1].discriminator_column for c in column_definitions] + ) column_definitions = [x for x in inherited_columns.items()] + column_definitions - discriminator_columns = [c for c in column_definitions if c[1].discriminator_column] + discriminator_columns = [ + c for c in column_definitions if c[1].discriminator_column + ] is_polymorphic = len(discriminator_columns) > 0 if len(discriminator_columns) > 1: - raise ModelDefinitionException('only one discriminator_column can be defined in a model, {0} found'.format(len(discriminator_columns))) + raise ModelDefinitionException( + "only one discriminator_column can be defined in a model, {0} found".format( + len(discriminator_columns) + ) + ) - if attrs['__discriminator_value__'] and not is_polymorphic: - raise ModelDefinitionException('__discriminator_value__ specified, but no base columns defined with discriminator_column=True') + if attrs["__discriminator_value__"] and not is_polymorphic: + raise ModelDefinitionException( + "__discriminator_value__ specified, but no base columns defined with discriminator_column=True" + ) - discriminator_column_name, discriminator_column = discriminator_columns[0] if discriminator_columns else (None, None) + discriminator_column_name, discriminator_column = ( + discriminator_columns[0] if discriminator_columns else (None, None) + ) - if isinstance(discriminator_column, (columns.BaseContainerColumn, columns.Counter)): - raise ModelDefinitionException('counter and container columns cannot be used as discriminator columns') + if isinstance( + discriminator_column, (columns.BaseContainerColumn, columns.Counter) + ): + raise ModelDefinitionException( + "counter and container columns cannot be used as discriminator columns" + ) # find polymorphic base class polymorphic_base = None if is_polymorphic and not is_polymorphic_base: + def _get_polymorphic_base(bases): for base in bases: - if getattr(base, '_is_polymorphic_base', False): + if getattr(base, "_is_polymorphic_base", False): return base klass = _get_polymorphic_base(base.__bases__) if klass: return klass + polymorphic_base = _get_polymorphic_base(bases) defined_columns = OrderedDict(column_definitions) @@ -899,10 +1000,16 @@ def _get_polymorphic_base(bases): if not is_abstract and not any([v.primary_key for k, v in column_definitions]): raise ModelDefinitionException("At least 1 primary key is required.") - counter_columns = [c for c in defined_columns.values() if isinstance(c, columns.Counter)] - data_columns = [c for c in defined_columns.values() if not c.primary_key and not isinstance(c, columns.Counter)] + counter_columns = [ + c for c in defined_columns.values() if isinstance(c, columns.Counter) + ] + data_columns = [ + c + for c in defined_columns.values() + if not c.primary_key and not isinstance(c, columns.Counter) + ] if counter_columns and data_columns: - raise ModelDefinitionException('counter models may not have data columns') + raise ModelDefinitionException("counter models may not have data columns") has_partition_keys = any(v.partition_key for (k, v) in column_definitions) @@ -919,11 +1026,15 @@ def _transform_column(col_name, col_obj): for k, v in column_definitions: # don't allow a column with the same name as a built-in attribute or method if k in BaseModel.__dict__: - raise ModelDefinitionException("column '{0}' conflicts with built-in attribute/method".format(k)) + raise ModelDefinitionException( + "column '{0}' conflicts with built-in attribute/method".format(k) + ) # counter column primary keys are not allowed if (v.primary_key or v.partition_key) and isinstance(v, columns.Counter): - raise ModelDefinitionException('counter columns cannot be used as primary keys') + raise ModelDefinitionException( + "counter columns cannot be used as primary keys" + ) # this will mark the first primary key column as a partition # key, if one hasn't been set already @@ -941,14 +1052,24 @@ def _transform_column(col_name, col_obj): v._partition_key_index = overriding._partition_key_index _transform_column(k, v) - partition_keys = OrderedDict(k for k in primary_keys.items() if k[1].partition_key) - clustering_keys = OrderedDict(k for k in primary_keys.items() if not k[1].partition_key) + partition_keys = OrderedDict( + k for k in primary_keys.items() if k[1].partition_key + ) + clustering_keys = OrderedDict( + k for k in primary_keys.items() if not k[1].partition_key + ) - if attrs.get('__compute_routing_key__', True): + if attrs.get("__compute_routing_key__", True): key_cols = [c for c in partition_keys.values()] - partition_key_index = dict((col.db_field_name, col._partition_key_index) for col in key_cols) + partition_key_index = dict( + (col.db_field_name, col._partition_key_index) for col in key_cols + ) key_cql_types = [c.cql_type for c in key_cols] - key_serializer = staticmethod(lambda parts, proto_version: [t.to_binary(p, proto_version) for t, p in zip(key_cql_types, parts)]) + key_serializer = staticmethod( + lambda parts, proto_version: [ + t.to_binary(p, proto_version) for t, p in zip(key_cql_types, parts) + ] + ) else: partition_key_index = {} key_serializer = staticmethod(lambda parts, proto_version: None) @@ -959,23 +1080,37 @@ def _transform_column(col_name, col_obj): raise ModelException("at least one partition key must be defined") if len(partition_keys) == 1: pk_name = [x for x in partition_keys.keys()][0] - attrs['pk'] = attrs[pk_name] + attrs["pk"] = attrs[pk_name] else: # composite partition key case, get/set a tuple of values - _get = lambda self: tuple(self._values[c].getval() for c in partition_keys.keys()) - _set = lambda self, val: tuple(self._values[c].setval(v) for (c, v) in zip(partition_keys.keys(), val)) - attrs['pk'] = property(_get, _set) + _get = lambda self: tuple( + self._values[c].getval() for c in partition_keys.keys() + ) + _set = lambda self, val: tuple( + self._values[c].setval(v) for (c, v) in zip(partition_keys.keys(), val) + ) + attrs["pk"] = property(_get, _set) # some validation col_names = set() for v in column_dict.values(): # check for duplicate column names if v.db_field_name in col_names: - raise ModelException("{0} defines the column '{1}' more than once".format(name, v.db_field_name)) + raise ModelException( + "{0} defines the column '{1}' more than once".format( + name, v.db_field_name + ) + ) if v.clustering_order and not (v.primary_key and not v.partition_key): - raise ModelException("clustering_order may be specified only for clustering primary keys") - if v.clustering_order and v.clustering_order.lower() not in ('asc', 'desc'): - raise ModelException("invalid clustering order '{0}' for column '{1}'".format(repr(v.clustering_order), v.db_field_name)) + raise ModelException( + "clustering_order may be specified only for clustering primary keys" + ) + if v.clustering_order and v.clustering_order.lower() not in ("asc", "desc"): + raise ModelException( + "invalid clustering order '{0}' for column '{1}'".format( + repr(v.clustering_order), v.db_field_name + ) + ) col_names.add(v.db_field_name) # create db_name -> model name map for loading @@ -986,47 +1121,53 @@ def _transform_column(col_name, col_obj): db_map[db_field] = col_name # add management members to the class - attrs['_columns'] = column_dict - attrs['_primary_keys'] = primary_keys - attrs['_defined_columns'] = defined_columns + attrs["_columns"] = column_dict + attrs["_primary_keys"] = primary_keys + attrs["_defined_columns"] = defined_columns # maps the database field to the models key - attrs['_db_map'] = db_map - attrs['_pk_name'] = pk_name - attrs['_dynamic_columns'] = {} + attrs["_db_map"] = db_map + attrs["_pk_name"] = pk_name + attrs["_dynamic_columns"] = {} - attrs['_partition_keys'] = partition_keys - attrs['_partition_key_index'] = partition_key_index - attrs['_key_serializer'] = key_serializer - attrs['_clustering_keys'] = clustering_keys - attrs['_has_counter'] = len(counter_columns) > 0 + attrs["_partition_keys"] = partition_keys + attrs["_partition_key_index"] = partition_key_index + attrs["_key_serializer"] = key_serializer + attrs["_clustering_keys"] = clustering_keys + attrs["_has_counter"] = len(counter_columns) > 0 # add polymorphic management attributes - attrs['_is_polymorphic_base'] = is_polymorphic_base - attrs['_is_polymorphic'] = is_polymorphic - attrs['_polymorphic_base'] = polymorphic_base - attrs['_discriminator_column'] = discriminator_column - attrs['_discriminator_column_name'] = discriminator_column_name - attrs['_discriminator_map'] = {} if is_polymorphic_base else None + attrs["_is_polymorphic_base"] = is_polymorphic_base + attrs["_is_polymorphic"] = is_polymorphic + attrs["_polymorphic_base"] = polymorphic_base + attrs["_discriminator_column"] = discriminator_column + attrs["_discriminator_column_name"] = discriminator_column_name + attrs["_discriminator_map"] = {} if is_polymorphic_base else None # setup class exceptions DoesNotExistBase = None for base in bases: - DoesNotExistBase = getattr(base, 'DoesNotExist', None) + DoesNotExistBase = getattr(base, "DoesNotExist", None) if DoesNotExistBase is not None: break - DoesNotExistBase = DoesNotExistBase or attrs.pop('DoesNotExist', BaseModel.DoesNotExist) - attrs['DoesNotExist'] = type('DoesNotExist', (DoesNotExistBase,), {}) + DoesNotExistBase = DoesNotExistBase or attrs.pop( + "DoesNotExist", BaseModel.DoesNotExist + ) + attrs["DoesNotExist"] = type("DoesNotExist", (DoesNotExistBase,), {}) MultipleObjectsReturnedBase = None for base in bases: - MultipleObjectsReturnedBase = getattr(base, 'MultipleObjectsReturned', None) + MultipleObjectsReturnedBase = getattr(base, "MultipleObjectsReturned", None) if MultipleObjectsReturnedBase is not None: break - MultipleObjectsReturnedBase = MultipleObjectsReturnedBase or attrs.pop('MultipleObjectsReturned', BaseModel.MultipleObjectsReturned) - attrs['MultipleObjectsReturned'] = type('MultipleObjectsReturned', (MultipleObjectsReturnedBase,), {}) + MultipleObjectsReturnedBase = MultipleObjectsReturnedBase or attrs.pop( + "MultipleObjectsReturned", BaseModel.MultipleObjectsReturned + ) + attrs["MultipleObjectsReturned"] = type( + "MultipleObjectsReturned", (MultipleObjectsReturnedBase,), {} + ) # create the class and add a QuerySet to it klass = super(ModelMetaClass, cls).__new__(cls, name, bases, attrs) diff --git a/cassandra/cqlengine/named.py b/cassandra/cqlengine/named.py index 265d5c91e4..7bfa1ee5cf 100644 --- a/cassandra/cqlengine/named.py +++ b/cassandra/cqlengine/named.py @@ -20,7 +20,9 @@ from cassandra.cqlengine.models import UsingDescriptor, BaseModel from cassandra.cqlengine.query import AbstractQueryableColumn, SimpleQuerySet from cassandra.cqlengine.query import DoesNotExist as _DoesNotExist -from cassandra.cqlengine.query import MultipleObjectsReturned as _MultipleObjectsReturned +from cassandra.cqlengine.query import ( + MultipleObjectsReturned as _MultipleObjectsReturned, +) class QuerySetDescriptor(object): @@ -30,9 +32,9 @@ class QuerySetDescriptor(object): """ def __get__(self, obj, model): - """ :rtype: ModelQuerySet """ + """:rtype: ModelQuerySet""" if model.__abstract__: - raise CQLEngineException('cannot execute queries against abstract models') + raise CQLEngineException("cannot execute queries against abstract models") return SimpleQuerySet(obj) def __call__(self, *args, **kwargs): @@ -52,11 +54,11 @@ class NamedColumn(AbstractQueryableColumn): def __init__(self, name): self.name = name - def __unicode__(self): + def __str__(self): return self.name def _get_column(self): - """ :rtype: NamedColumn """ + """:rtype: NamedColumn""" return self @property @@ -113,11 +115,25 @@ def _partition_keys(self): def _get_partition_keys(self): try: - table_meta = get_cluster(self._get_connection()).metadata.keyspaces[self.keyspace].tables[self.name] - self.__partition_keys = OrderedDict((pk.name, Column(primary_key=True, partition_key=True, db_field=pk.name)) for pk in table_meta.partition_key) + table_meta = ( + get_cluster(self._get_connection()) + .metadata.keyspaces[self.keyspace] + .tables[self.name] + ) + self.__partition_keys = OrderedDict( + ( + pk.name, + Column(primary_key=True, partition_key=True, db_field=pk.name), + ) + for pk in table_meta.partition_key + ) except Exception as e: - raise CQLEngineException("Failed inspecting partition keys for {0}." - "Ensure cqlengine is connected before attempting this with NamedTable.".format(self.column_family_name())) + raise CQLEngineException( + "Failed inspecting partition keys for {0}." + "Ensure cqlengine is connected before attempting this with NamedTable.".format( + self.column_family_name() + ) + ) def column(self, name): return NamedColumn(name) @@ -128,7 +144,7 @@ def column_family_name(self, include_keyspace=True): otherwise, it creates it from the module and class name """ if include_keyspace: - return '{0}.{1}'.format(self.keyspace, self.name) + return "{0}.{1}".format(self.keyspace, self.name) else: return self.name diff --git a/cassandra/cqlengine/operators.py b/cassandra/cqlengine/operators.py index 2adf51758d..4a59373b5b 100644 --- a/cassandra/cqlengine/operators.py +++ b/cassandra/cqlengine/operators.py @@ -27,16 +27,15 @@ class BaseQueryOperator(UnicodeMixin): # The comparator symbol this operator uses in cql cql_symbol = None - def __unicode__(self): + def __str__(self): if self.cql_symbol is None: raise QueryOperatorException("cql symbol is None") return self.cql_symbol class OpMapMeta(type): - def __init__(cls, name, bases, dct): - if not hasattr(cls, 'opmap'): + if not hasattr(cls, "opmap"): cls.opmap = {} else: cls.opmap[cls.symbol] = cls @@ -44,60 +43,63 @@ def __init__(cls, name, bases, dct): class BaseWhereOperator(BaseQueryOperator, metaclass=OpMapMeta): - """ base operator used for where clauses """ + """base operator used for where clauses""" + @classmethod def get_operator(cls, symbol): try: return cls.opmap[symbol.upper()] except KeyError: - raise QueryOperatorException("{0} doesn't map to a QueryOperator".format(symbol)) + raise QueryOperatorException( + "{0} doesn't map to a QueryOperator".format(symbol) + ) class EqualsOperator(BaseWhereOperator): - symbol = 'EQ' - cql_symbol = '=' + symbol = "EQ" + cql_symbol = "=" class NotEqualsOperator(BaseWhereOperator): - symbol = 'NE' - cql_symbol = '!=' + symbol = "NE" + cql_symbol = "!=" class InOperator(EqualsOperator): - symbol = 'IN' - cql_symbol = 'IN' + symbol = "IN" + cql_symbol = "IN" class GreaterThanOperator(BaseWhereOperator): symbol = "GT" - cql_symbol = '>' + cql_symbol = ">" class GreaterThanOrEqualOperator(BaseWhereOperator): symbol = "GTE" - cql_symbol = '>=' + cql_symbol = ">=" class LessThanOperator(BaseWhereOperator): symbol = "LT" - cql_symbol = '<' + cql_symbol = "<" class LessThanOrEqualOperator(BaseWhereOperator): symbol = "LTE" - cql_symbol = '<=' + cql_symbol = "<=" class ContainsOperator(EqualsOperator): symbol = "CONTAINS" - cql_symbol = 'CONTAINS' + cql_symbol = "CONTAINS" class LikeOperator(EqualsOperator): symbol = "LIKE" - cql_symbol = 'LIKE' + cql_symbol = "LIKE" class IsNotNullOperator(EqualsOperator): symbol = "IS NOT NULL" - cql_symbol = 'IS NOT NULL' + cql_symbol = "IS NOT NULL" From da860871bbb1cde9a4c6bb52adaed1a27758448a Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Wed, 4 Mar 2026 19:06:24 +0200 Subject: [PATCH 08/18] remove: delete UnicodeMixin class and all references UnicodeMixin was a Python 2/3 compatibility shim that wired __str__ to call __unicode__(). Now that all __unicode__ methods have been renamed to __str__ in prior commits, UnicodeMixin serves no purpose. - Delete the UnicodeMixin class from cassandra/cqlengine/__init__.py - Remove UnicodeMixin from the inheritance list of 6 classes: ValueQuoter, BaseClause, BaseCQLStatement, AbstractQueryableColumn, QueryValue, BaseQueryOperator - Remove all 'from cassandra.cqlengine import UnicodeMixin' imports The classes now define __str__ directly, which is the standard Python 3 approach for string representation. --- cassandra/cqlengine/__init__.py | 4 ---- cassandra/cqlengine/functions.py | 4 ++-- cassandra/cqlengine/operators.py | 4 +--- cassandra/cqlengine/query.py | 3 +-- cassandra/cqlengine/statements.py | 7 +++---- 5 files changed, 7 insertions(+), 15 deletions(-) diff --git a/cassandra/cqlengine/__init__.py b/cassandra/cqlengine/__init__.py index b9466e961b..0ee4b5ddd0 100644 --- a/cassandra/cqlengine/__init__.py +++ b/cassandra/cqlengine/__init__.py @@ -25,7 +25,3 @@ class CQLEngineException(Exception): class ValidationError(CQLEngineException): pass - - -class UnicodeMixin(object): - __str__ = lambda x: x.__unicode__() diff --git a/cassandra/cqlengine/functions.py b/cassandra/cqlengine/functions.py index c408d8096a..37292ff3e2 100644 --- a/cassandra/cqlengine/functions.py +++ b/cassandra/cqlengine/functions.py @@ -14,14 +14,14 @@ from datetime import datetime -from cassandra.cqlengine import UnicodeMixin, ValidationError +from cassandra.cqlengine import ValidationError def get_total_seconds(td): return td.total_seconds() -class QueryValue(UnicodeMixin): +class QueryValue: """ Base class for query filter values. Subclasses of these classes can be passed into .filter() keyword args diff --git a/cassandra/cqlengine/operators.py b/cassandra/cqlengine/operators.py index 4a59373b5b..a366a3b9a8 100644 --- a/cassandra/cqlengine/operators.py +++ b/cassandra/cqlengine/operators.py @@ -12,14 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from cassandra.cqlengine import UnicodeMixin - class QueryOperatorException(Exception): pass -class BaseQueryOperator(UnicodeMixin): +class BaseQueryOperator: # The symbol that identifies this operator in kwargs # ie: colname__ symbol = None diff --git a/cassandra/cqlengine/query.py b/cassandra/cqlengine/query.py index 0d492bc016..79b69b92be 100644 --- a/cassandra/cqlengine/query.py +++ b/cassandra/cqlengine/query.py @@ -23,7 +23,6 @@ columns, CQLEngineException, ValidationError, - UnicodeMixin, ) from cassandra.cqlengine import connection as conn from cassandra.cqlengine.functions import Token, BaseQueryFunction, QueryValue @@ -98,7 +97,7 @@ def check_applied(result): raise LWTException(result.one()) -class AbstractQueryableColumn(UnicodeMixin): +class AbstractQueryableColumn: """ exposes cql query operators through pythons builtin comparator symbols diff --git a/cassandra/cqlengine/statements.py b/cassandra/cqlengine/statements.py index b1e4e43329..f407932fd8 100644 --- a/cassandra/cqlengine/statements.py +++ b/cassandra/cqlengine/statements.py @@ -17,7 +17,6 @@ from cassandra.query import FETCH_SIZE_UNSET from cassandra.cqlengine import columns -from cassandra.cqlengine import UnicodeMixin from cassandra.cqlengine.functions import QueryValue from cassandra.cqlengine.operators import ( BaseWhereOperator, @@ -31,7 +30,7 @@ class StatementException(Exception): pass -class ValueQuoter(UnicodeMixin): +class ValueQuoter: def __init__(self, value): self.value = value @@ -65,7 +64,7 @@ def __str__(self): return "(" + ", ".join([cql_quote(v) for v in self.value]) + ")" -class BaseClause(UnicodeMixin): +class BaseClause: def __init__(self, field, value): self.field = field self.value = value @@ -537,7 +536,7 @@ def __str__(self): ) -class BaseCQLStatement(UnicodeMixin): +class BaseCQLStatement: """The base cql statement class""" def __init__( From c5efe7cbd16d4545d3623d65e2b352fea77acaff Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Wed, 4 Mar 2026 19:07:43 +0200 Subject: [PATCH 09/18] remove: delete 'from __future__ import absolute_import' from 5 files absolute_import became the default behavior in Python 3.0. These imports were needed in Python 2 to prevent relative imports from shadowing stdlib modules (e.g., 'import io' resolving to a local module instead of the stdlib). Since the driver requires Python 3.9+, these are dead code. Removed from: cassandra/protocol.py, cassandra/cqltypes.py, cassandra/connection.py, cassandra/cluster.py, and tests/integration/cqlengine/query/test_queryset.py. --- cassandra/cluster.py | 2 - cassandra/connection.py | 804 ++++++++++++----- cassandra/cqltypes.py | 1 - cassandra/protocol.py | 568 ++++++++---- .../cqlengine/query/test_queryset.py | 832 ++++++++++++------ 5 files changed, 1532 insertions(+), 675 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 569bb578f1..547290ff0f 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -17,8 +17,6 @@ :class:`.Cluster` and :class:`.Session`. """ -from __future__ import absolute_import - import atexit import datetime from binascii import hexlify diff --git a/cassandra/connection.py b/cassandra/connection.py index 87f860f32b..bfd2ea79e1 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import absolute_import # to enable import io from stdlib from collections import defaultdict, deque import errno from functools import wraps, partial, total_ordering @@ -33,20 +32,37 @@ from cassandra.application_info import ApplicationInfoBase from cassandra.protocol_features import ProtocolFeatures -if 'gevent.monkey' in sys.modules: +if "gevent.monkey" in sys.modules: from gevent.queue import Queue, Empty else: from queue import Queue, Empty # noqa -from cassandra import ConsistencyLevel, AuthenticationFailed, OperationTimedOut, ProtocolVersion +from cassandra import ( + ConsistencyLevel, + AuthenticationFailed, + OperationTimedOut, + ProtocolVersion, +) from cassandra.marshal import int32_pack -from cassandra.protocol import (ReadyMessage, AuthenticateMessage, OptionsMessage, - StartupMessage, ErrorMessage, CredentialsMessage, - QueryMessage, ResultMessage, ProtocolHandler, - InvalidRequestException, SupportedMessage, - AuthResponseMessage, AuthChallengeMessage, - AuthSuccessMessage, ProtocolException, - RegisterMessage, ReviseRequestMessage) +from cassandra.protocol import ( + ReadyMessage, + AuthenticateMessage, + OptionsMessage, + StartupMessage, + ErrorMessage, + CredentialsMessage, + QueryMessage, + ResultMessage, + ProtocolHandler, + InvalidRequestException, + SupportedMessage, + AuthResponseMessage, + AuthChallengeMessage, + AuthSuccessMessage, + ProtocolException, + RegisterMessage, + ReviseRequestMessage, +) from cassandra.segment import SegmentCodec, CrcException from cassandra.util import OrderedDict from cassandra.shard_info import ShardingInfo @@ -64,7 +80,9 @@ try: import lz4 except ImportError: - log.debug("lz4 package could not be imported. LZ4 Compression will not be available") + log.debug( + "lz4 package could not be imported. LZ4 Compression will not be available" + ) pass else: # The compress and decompress functions we need were moved from the lz4 to @@ -79,10 +97,10 @@ lz4_block.decompress except AttributeError: raise ImportError( - 'lz4 not imported correctly. Imported object should have ' - '.compress and and .decompress attributes but does not. ' - 'Please file a bug report on JIRA. (Imported object was ' - '{lz4_block})'.format(lz4_block=repr(lz4_block)) + "lz4 not imported correctly. Imported object should have " + ".compress and and .decompress attributes but does not. " + "Please file a bug report on JIRA. (Imported object was " + "{lz4_block})".format(lz4_block=repr(lz4_block)) ) # Cassandra writes the uncompressed message length in big endian order, @@ -97,25 +115,31 @@ def lz4_decompress(byts): # flip from big-endian to little-endian return lz4_block.decompress(byts[3::-1] + byts[4:]) - locally_supported_compressions['lz4'] = (lz4_compress, lz4_decompress) + locally_supported_compressions["lz4"] = (lz4_compress, lz4_decompress) segment_codec_lz4 = SegmentCodec(lz4_compress, lz4_decompress) try: import snappy except ImportError: - log.debug("snappy package could not be imported. Snappy Compression will not be available") + log.debug( + "snappy package could not be imported. Snappy Compression will not be available" + ) pass else: # work around apparently buggy snappy decompress def decompress(byts): - if byts == '\x00': - return '' + if byts == "\x00": + return "" return snappy.decompress(byts) - locally_supported_compressions['snappy'] = (snappy.compress, decompress) -DRIVER_NAME, DRIVER_VERSION = 'ScyllaDB Python Driver', sys.modules['cassandra'].__version__ + locally_supported_compressions["snappy"] = (snappy.compress, decompress) -PROTOCOL_VERSION_MASK = 0x7f +DRIVER_NAME, DRIVER_VERSION = ( + "ScyllaDB Python Driver", + sys.modules["cassandra"].__version__, +) + +PROTOCOL_VERSION_MASK = 0x7F HEADER_DIRECTION_FROM_CLIENT = 0x00 HEADER_DIRECTION_TO_CLIENT = 0x80 @@ -125,7 +149,7 @@ def decompress(byts): DEFAULT_LOCAL_PORT_LOW = 49152 DEFAULT_LOCAL_PORT_HIGH = 65535 -frame_header_v3 = struct.Struct('>BhBi') +frame_header_v3 = struct.Struct(">BhBi") class EndPoint(object): @@ -170,7 +194,6 @@ def resolve(self): class EndPointFactory(object): - cluster = None def configure(self, cluster): @@ -209,8 +232,11 @@ def resolve(self): return self._address, self._port def __eq__(self, other): - return isinstance(other, DefaultEndPoint) and \ - self.address == other.address and self.port == other.port + return ( + isinstance(other, DefaultEndPoint) + and self.address == other.address + and self.port == other.port + ) def __hash__(self): return hash((self.address, self.port)) @@ -226,7 +252,6 @@ def __repr__(self): class DefaultEndPointFactory(EndPointFactory): - port = None """ If no port is discovered in the row, this is the default port @@ -239,6 +264,7 @@ def __init__(self, port=None): def create(self, row): # TODO next major... move this class so we don't need this kind of hack from cassandra.metadata import _NodeInfo + addr = _NodeInfo.get_broadcast_rpc_address(row) port = _NodeInfo.get_broadcast_rpc_port(row) if port is None: @@ -246,9 +272,7 @@ def create(self, row): # create the endpoint with the translated address # TODO next major, create a TranslatedEndPoint type - return DefaultEndPoint( - self.cluster.address_translator.translate(addr), - port) + return DefaultEndPoint(self.cluster.address_translator.translate(addr), port) @total_ordering @@ -261,7 +285,7 @@ def __init__(self, proxy_address, server_name, port=9042): self._resolved_address = None # resolved address self._port = port self._server_name = server_name - self._ssl_options = {'server_hostname': server_name} + self._ssl_options = {"server_hostname": server_name} @property def address(self): @@ -277,41 +301,55 @@ def ssl_options(self): def resolve(self): try: - resolved_addresses = socket.getaddrinfo(self._proxy_address, self._port, - socket.AF_UNSPEC, socket.SOCK_STREAM) + resolved_addresses = socket.getaddrinfo( + self._proxy_address, self._port, socket.AF_UNSPEC, socket.SOCK_STREAM + ) except socket.gaierror: - log.debug('Could not resolve sni proxy hostname "%s" ' - 'with port %d' % (self._proxy_address, self._port)) + log.debug( + 'Could not resolve sni proxy hostname "%s" ' + "with port %d" % (self._proxy_address, self._port) + ) raise # round-robin pick - self._resolved_address = sorted(addr[4][0] for addr in resolved_addresses)[self._index % len(resolved_addresses)] + self._resolved_address = sorted(addr[4][0] for addr in resolved_addresses)[ + self._index % len(resolved_addresses) + ] self._index += 1 return self._resolved_address, self._port def __eq__(self, other): - return (isinstance(other, SniEndPoint) and - self.address == other.address and self.port == other.port and - self._server_name == other._server_name) + return ( + isinstance(other, SniEndPoint) + and self.address == other.address + and self.port == other.port + and self._server_name == other._server_name + ) def __hash__(self): return hash((self.address, self.port, self._server_name)) def __lt__(self, other): - return ((self.address, self.port, self._server_name) < - (other.address, other.port, self._server_name)) + return (self.address, self.port, self._server_name) < ( + other.address, + other.port, + self._server_name, + ) def __str__(self): return str("%s:%d:%s" % (self.address, self.port, self._server_name)) def __repr__(self): - return "<%s: %s:%d:%s>" % (self.__class__.__name__, - self.address, self.port, self._server_name) + return "<%s: %s:%d:%s>" % ( + self.__class__.__name__, + self.address, + self.port, + self._server_name, + ) class SniEndPointFactory(EndPointFactory): - def __init__(self, proxy_address, port, node_domain=None): self._proxy_address = proxy_address self._port = port @@ -321,7 +359,11 @@ def create(self, row): host_id = row.get("host_id") if host_id is None: raise ValueError("No host_id to create the SniEndPoint") - address = "{}.{}".format(host_id, self._node_domain) if self._node_domain else str(host_id) + address = ( + "{}.{}".format(host_id, self._node_domain) + if self._node_domain + else str(host_id) + ) return SniEndPoint(self._proxy_address, str(address), self._port) def create_from_sni(self, sni): @@ -353,8 +395,10 @@ def resolve(self): return self.address, None def __eq__(self, other): - return (isinstance(other, UnixSocketEndPoint) and - self._unix_socket_path == other._unix_socket_path) + return ( + isinstance(other, UnixSocketEndPoint) + and self._unix_socket_path == other._unix_socket_path + ) def __hash__(self): return hash(self._unix_socket_path) @@ -380,16 +424,25 @@ def __init__(self, version, flags, stream, opcode, body_offset, end_pos): def __eq__(self, other): # facilitates testing if isinstance(other, _Frame): - return (self.version == other.version and - self.flags == other.flags and - self.stream == other.stream and - self.opcode == other.opcode and - self.body_offset == other.body_offset and - self.end_pos == other.end_pos) + return ( + self.version == other.version + and self.flags == other.flags + and self.stream == other.stream + and self.opcode == other.opcode + and self.body_offset == other.body_offset + and self.end_pos == other.end_pos + ) return NotImplemented def __str__(self): - return "ver({0}); flags({1:04b}); stream({2}); op({3}); offset({4}); len({5})".format(self.version, self.flags, self.stream, self.opcode, self.body_offset, self.end_pos - self.body_offset) + return "ver({0}); flags({1:04b}); stream({2}); op({3}); offset({4}); len({5})".format( + self.version, + self.flags, + self.stream, + self.opcode, + self.body_offset, + self.end_pos - self.body_offset, + ) NONBLOCKING = (errno.EAGAIN, errno.EWOULDBLOCK) @@ -414,6 +467,7 @@ class ConnectionShutdown(ConnectionException): """ Raised when a connection has been marked as defunct or has been closed. """ + pass @@ -421,6 +475,7 @@ class ProtocolVersionUnsupported(ConnectionException): """ Server rejected startup message due to unsupported protocol version """ + def __init__(self, endpoint, startup_version): msg = "Unsupported protocol version on %s: %d" % (endpoint, startup_version) super(ProtocolVersionUnsupported, self).__init__(msg, endpoint) @@ -432,6 +487,7 @@ class ConnectionBusy(Exception): An attempt was made to send a message through a :class:`.Connection` that was already at the max number of in-flight operations. """ + pass @@ -439,12 +495,14 @@ class ProtocolError(Exception): """ Communication did not match the protocol that this driver expects. """ + pass class CrcMismatchException(ConnectionException): pass + class ContinuousPagingSession(object): def __init__(self, stream_id, decoder, row_factory, connection, state): self.stream_id = stream_id @@ -519,9 +577,15 @@ def maybe_request_more(self): max_queue_size = self._state.max_queue_size num_in_flight = self._state.num_pages_requested - self._state.num_pages_received space_in_queue = max_queue_size - len(self._page_queue) - num_in_flight - log.debug("Session %s from %s, space in CP queue: %s, requested: %s, received: %s, num_in_flight: %s", - self.stream_id, self.connection.host, space_in_queue, self._state.num_pages_requested, - self._state.num_pages_received, num_in_flight) + log.debug( + "Session %s from %s, space in CP queue: %s, requested: %s, received: %s, num_in_flight: %s", + self.stream_id, + self.connection.host, + space_in_queue, + self._state.num_pages_requested, + self._state.num_pages_received, + num_in_flight, + ) if space_in_queue >= max_queue_size / 2: self.update_next_pages(space_in_queue) @@ -529,37 +593,64 @@ def maybe_request_more(self): def update_next_pages(self, num_next_pages): try: self._state.num_pages_requested += num_next_pages - log.debug("Updating backpressure for session %s from %s", self.stream_id, self.connection.host) + log.debug( + "Updating backpressure for session %s from %s", + self.stream_id, + self.connection.host, + ) with self.connection.lock: - self.connection.send_msg(ReviseRequestMessage(ReviseRequestMessage.RevisionType.PAGING_BACKPRESSURE, - self.stream_id, - next_pages=num_next_pages), - self.connection.get_request_id(), - self._on_backpressure_response) + self.connection.send_msg( + ReviseRequestMessage( + ReviseRequestMessage.RevisionType.PAGING_BACKPRESSURE, + self.stream_id, + next_pages=num_next_pages, + ), + self.connection.get_request_id(), + self._on_backpressure_response, + ) except ConnectionShutdown as ex: - log.debug("Failed to update backpressure for session %s from %s, connection is shutdown", - self.stream_id, self.connection.host) + log.debug( + "Failed to update backpressure for session %s from %s, connection is shutdown", + self.stream_id, + self.connection.host, + ) self.on_error(ex) def _on_backpressure_response(self, response): if isinstance(response, ResultMessage): log.debug("Paging session %s backpressure updated.", self.stream_id) else: - log.error("Failed updating backpressure for session %s from %s: %s", self.stream_id, self.connection.host, - response.to_exception() if hasattr(response, 'to_exception') else response) + log.error( + "Failed updating backpressure for session %s from %s: %s", + self.stream_id, + self.connection.host, + response.to_exception() + if hasattr(response, "to_exception") + else response, + ) self.on_error(response) def cancel(self): try: - log.debug("Canceling paging session %s from %s", self.stream_id, self.connection.host) + log.debug( + "Canceling paging session %s from %s", + self.stream_id, + self.connection.host, + ) with self.connection.lock: - self.connection.send_msg(ReviseRequestMessage(ReviseRequestMessage.RevisionType.PAGING_CANCEL, - self.stream_id), - self.connection.get_request_id(), - self._on_cancel_response) + self.connection.send_msg( + ReviseRequestMessage( + ReviseRequestMessage.RevisionType.PAGING_CANCEL, self.stream_id + ), + self.connection.get_request_id(), + self._on_cancel_response, + ) except ConnectionShutdown: - log.debug("Failed to cancel session %s from %s, connection is shutdown", - self.stream_id, self.connection.host) + log.debug( + "Failed to cancel session %s from %s, connection is shutdown", + self.stream_id, + self.connection.host, + ) with self._condition: self._stop = True @@ -569,8 +660,14 @@ def _on_cancel_response(self, response): if isinstance(response, ResultMessage): log.debug("Paging session %s canceled.", self.stream_id) else: - log.error("Failed canceling streaming session %s from %s: %s", self.stream_id, self.connection.host, - response.to_exception() if hasattr(response, 'to_exception') else response) + log.error( + "Failed canceling streaming session %s from %s: %s", + self.stream_id, + self.connection.host, + response.to_exception() + if hasattr(response, "to_exception") + else response, + ) self.released = True @@ -582,10 +679,11 @@ def wrapper(self, *args, **kwargs): return f(self, *args, **kwargs) except Exception as exc: self.defunct(exc) + return wrapper -DEFAULT_CQL_VERSION = '3.0.0' +DEFAULT_CQL_VERSION = "3.0.0" class _ConnectionIOBuffer(object): @@ -594,6 +692,7 @@ class _ConnectionIOBuffer(object): protocol V5 and checksumming, the data is read, validated and copied to another cql frame buffer. """ + _io_buffer = None _cql_frame_buffer = None _connection = None @@ -609,8 +708,9 @@ def io_buffer(self): @property def cql_frame_buffer(self): - return self._cql_frame_buffer if self.is_checksumming_enabled else \ - self._io_buffer + return ( + self._cql_frame_buffer if self.is_checksumming_enabled else self._io_buffer + ) def set_checksumming_buffer(self): self.reset_io_buffer() @@ -622,7 +722,7 @@ def is_checksumming_enabled(self): @property def has_consumed_segment(self): - return self._segment_consumed; + return self._segment_consumed def readable_io_bytes(self): return self.io_buffer.tell() @@ -655,20 +755,26 @@ def _align(value: int, total_shards: int): return value + total_shards - shift def generate(self, shard_id: int, total_shards: int): - start = self._align(random.randrange(self.start_port, self.end_port), total_shards) + shard_id + start = ( + self._align(random.randrange(self.start_port, self.end_port), total_shards) + + shard_id + ) beginning = self._align(self.start_port, total_shards) + shard_id - available_ports = itertools.chain(range(start, self.end_port, total_shards), - range(beginning, start, total_shards)) + available_ports = itertools.chain( + range(start, self.end_port, total_shards), + range(beginning, start, total_shards), + ) for port in available_ports: yield port -DefaultShardAwarePortGenerator = ShardAwarePortGenerator(DEFAULT_LOCAL_PORT_LOW, DEFAULT_LOCAL_PORT_HIGH) +DefaultShardAwarePortGenerator = ShardAwarePortGenerator( + DEFAULT_LOCAL_PORT_LOW, DEFAULT_LOCAL_PORT_HIGH +) class Connection(object): - CALLBACK_ERR_THREAD_THRESHOLD = 100 in_buffer_size = 4096 @@ -698,7 +804,7 @@ class Connection(object): # all request ids to be used in protocol version 3+. Normally concurrency would be controlled # at a higher level by the application or concurrent.execute_concurrent. This attribute # is for lower-level integrations that want some upper bound without reimplementing. - max_in_flight = 2 ** 15 + max_in_flight = 2**15 # A set of available request IDs. When using the v3 protocol or higher, # this will not initially include all request IDs in order to save memory, @@ -721,7 +827,7 @@ class Connection(object): # If the number of orphaned streams reaches this threshold, this connection # will become marked and will be replaced with a new connection by the # owning pool (currently, only HostConnection supports this) - orphaned_threshold = 3 * max_in_flight // 4 + orphaned_threshold = 3 * max_in_flight // 4 is_defunct = False is_closed = False @@ -759,14 +865,32 @@ def _iobuf(self): # backward compatibility, to avoid any change in the reactors return self._io_buffer.io_buffer - def __init__(self, host='127.0.0.1', port=9042, authenticator=None, - ssl_options=None, sockopts=None, compression: Union[bool, str] = True, - cql_version=None, protocol_version=ProtocolVersion.MAX_SUPPORTED, is_control_connection=False, - user_type_map=None, connect_timeout=None, allow_beta_protocol_version=False, no_compact=False, - ssl_context=None, owning_pool=None, shard_id=None, total_shards=None, - on_orphaned_stream_released=None, application_info: Optional[ApplicationInfoBase] = None): + def __init__( + self, + host="127.0.0.1", + port=9042, + authenticator=None, + ssl_options=None, + sockopts=None, + compression: Union[bool, str] = True, + cql_version=None, + protocol_version=ProtocolVersion.MAX_SUPPORTED, + is_control_connection=False, + user_type_map=None, + connect_timeout=None, + allow_beta_protocol_version=False, + no_compact=False, + ssl_context=None, + owning_pool=None, + shard_id=None, + total_shards=None, + on_orphaned_stream_released=None, + application_info: Optional[ApplicationInfoBase] = None, + ): # TODO next major rename host to endpoint and remove port kwarg. - self.endpoint = host if isinstance(host, EndPoint) else DefaultEndPoint(host, port) + self.endpoint = ( + host if isinstance(host, EndPoint) else DefaultEndPoint(host, port) + ) self.authenticator = authenticator self.ssl_options = ssl_options.copy() if ssl_options else {} @@ -805,7 +929,7 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None, if not self.ssl_context and self.ssl_options: self.ssl_context = self._build_ssl_context_from_options() - self.max_request_id = min(self.max_in_flight - 1, (2 ** 15) - 1) + self.max_request_id = min(self.max_in_flight - 1, (2**15) - 1) # Don't fill the deque with 2**15 items right away. Start with some and add # more if needed. initial_size = min(300, self.max_in_flight) @@ -847,14 +971,14 @@ def create_timer(cls, timeout, callback): raise NotImplementedError() @classmethod - def factory(cls, endpoint, timeout, host_conn = None, *args, **kwargs): + def factory(cls, endpoint, timeout, host_conn=None, *args, **kwargs): """ A factory function which returns connections which have succeeded in connecting and are ready for service (or raises an exception otherwise). """ start = time.time() - kwargs['connect_timeout'] = timeout + kwargs["connect_timeout"] = timeout conn = cls(endpoint, *args, **kwargs) if host_conn is not None: host_conn._pending_connections.append(conn) @@ -868,32 +992,46 @@ def factory(cls, endpoint, timeout, host_conn = None, *args, **kwargs): raise conn.last_error elif not conn.connected_event.is_set(): conn.close() - raise OperationTimedOut("Timed out creating connection (%s seconds)" % timeout) + raise OperationTimedOut( + "Timed out creating connection (%s seconds)" % timeout + ) else: return conn def _build_ssl_context_from_options(self): # Extract a subset of names from self.ssl_options which apply to SSLContext creation - ssl_context_opt_names = ['ssl_version', 'cert_reqs', 'check_hostname', 'keyfile', 'certfile', 'ca_certs', 'ciphers'] - opts = {k:self.ssl_options.get(k, None) for k in ssl_context_opt_names if k in self.ssl_options} + ssl_context_opt_names = [ + "ssl_version", + "cert_reqs", + "check_hostname", + "keyfile", + "certfile", + "ca_certs", + "ciphers", + ] + opts = { + k: self.ssl_options.get(k, None) + for k in ssl_context_opt_names + if k in self.ssl_options + } # Python >= 3.10 requires either PROTOCOL_TLS_CLIENT or PROTOCOL_TLS_SERVER so we'll get ahead of things by always # being explicit - ssl_version = opts.get('ssl_version', None) or ssl.PROTOCOL_TLS_CLIENT - cert_reqs = opts.get('cert_reqs', None) or ssl.CERT_REQUIRED + ssl_version = opts.get("ssl_version", None) or ssl.PROTOCOL_TLS_CLIENT + cert_reqs = opts.get("cert_reqs", None) or ssl.CERT_REQUIRED rv = ssl.SSLContext(protocol=int(ssl_version)) - rv.check_hostname = bool(opts.get('check_hostname', False)) + rv.check_hostname = bool(opts.get("check_hostname", False)) rv.options = int(cert_reqs) - certfile = opts.get('certfile', None) - keyfile = opts.get('keyfile', None) + certfile = opts.get("certfile", None) + keyfile = opts.get("keyfile", None) if certfile: rv.load_cert_chain(certfile, keyfile) - ca_certs = opts.get('ca_certs', None) + ca_certs = opts.get("ca_certs", None) if ca_certs: rv.load_verify_locations(ca_certs) - ciphers = opts.get('ciphers', None) + ciphers = opts.get("ciphers", None) if ciphers: rv.set_ciphers(ciphers) @@ -903,27 +1041,43 @@ def _wrap_socket_from_context(self): # Extract a subset of names from self.ssl_options which apply to SSLContext.wrap_socket (or at least the parts # of it that don't involve building an SSLContext under the covers) - wrap_socket_opt_names = ['server_side', 'do_handshake_on_connect', 'suppress_ragged_eofs', 'server_hostname'] - opts = {k:self.ssl_options.get(k, None) for k in wrap_socket_opt_names if k in self.ssl_options} + wrap_socket_opt_names = [ + "server_side", + "do_handshake_on_connect", + "suppress_ragged_eofs", + "server_hostname", + ] + opts = { + k: self.ssl_options.get(k, None) + for k in wrap_socket_opt_names + if k in self.ssl_options + } # PYTHON-1186: set the server_hostname only if the SSLContext has # check_hostname enabled and it is not already provided by the EndPoint ssl options - #opts['server_hostname'] = self.endpoint.address - if (self.ssl_context.check_hostname and 'server_hostname' not in opts): + # opts['server_hostname'] = self.endpoint.address + if self.ssl_context.check_hostname and "server_hostname" not in opts: server_hostname = self.endpoint.address - opts['server_hostname'] = server_hostname + opts["server_hostname"] = server_hostname return self.ssl_context.wrap_socket(self._socket, **opts) def _initiate_connection(self, sockaddr): if self.features.shard_id is not None: - for port in DefaultShardAwarePortGenerator.generate(self.features.shard_id, self.total_shards): + for port in DefaultShardAwarePortGenerator.generate( + self.features.shard_id, self.total_shards + ): try: - self._socket.bind(('', port)) + self._socket.bind(("", port)) break except Exception as ex: log.debug("port=%d couldn't bind cause: %s", port, str(ex)) - log.debug('connection (%r) port=%d should be shard_id=%d', id(self), port, port % self.total_shards) + log.debug( + "connection (%r) port=%d should be shard_id=%d", + id(self), + port, + port % self.total_shards, + ) self._socket.connect(sockaddr) @@ -936,12 +1090,16 @@ def _validate_hostname(self): def _get_socket_addresses(self): address, port = self.endpoint.resolve() - if hasattr(socket, 'AF_UNIX') and self.endpoint.socket_family == socket.AF_UNIX: + if hasattr(socket, "AF_UNIX") and self.endpoint.socket_family == socket.AF_UNIX: return [(socket.AF_UNIX, socket.SOCK_STREAM, 0, None, address)] - addresses = socket.getaddrinfo(address, port, self.endpoint.socket_family, socket.SOCK_STREAM) + addresses = socket.getaddrinfo( + address, port, self.endpoint.socket_family, socket.SOCK_STREAM + ) if not addresses: - raise ConnectionException("getaddrinfo returned empty list for %s" % (self.endpoint,)) + raise ConnectionException( + "getaddrinfo returned empty list for %s" % (self.endpoint,) + ) return addresses @@ -949,7 +1107,7 @@ def _connect_socket(self): sockerr = None addresses = self._get_socket_addresses() port = None - for (af, socktype, proto, _, sockaddr) in addresses: + for af, socktype, proto, _, sockaddr in addresses: try: self._socket = self._socket_impl.socket(af, socktype, proto) if self.ssl_context: @@ -977,8 +1135,11 @@ def _connect_socket(self): sockerr = err if sockerr: - raise socket.error(sockerr.errno, "Tried connecting to %s. Last error: %s" % - ([a[4] for a in addresses], sockerr.strerror or sockerr)) + raise socket.error( + sockerr.errno, + "Tried connecting to %s. Last error: %s" + % ([a[4] for a in addresses], sockerr.strerror or sockerr), + ) if self.sockopts: for args in self.sockopts: @@ -991,7 +1152,9 @@ def _enable_compression(self): def _enable_checksumming(self): self._io_buffer.set_checksumming_buffer() self._is_checksumming_enabled = True - self._segment_codec = segment_codec_lz4 if self.compressor else segment_codec_no_compression + self._segment_codec = ( + segment_codec_lz4 if self.compressor else segment_codec_no_compression + ) log.debug("Enabling protocol checksumming on connection (%s).", id(self)) def close(self): @@ -1006,11 +1169,16 @@ def defunct(self, exc): exc_info = sys.exc_info() # if we are not handling an exception, just use the passed exception, and don't try to format exc_info with the message if any(exc_info): - log.debug("Defuncting connection (%s) to %s:", - id(self), self.endpoint, exc_info=exc_info) + log.debug( + "Defuncting connection (%s) to %s:", + id(self), + self.endpoint, + exc_info=exc_info, + ) else: - log.debug("Defuncting connection (%s) to %s: %s", - id(self), self.endpoint, exc) + log.debug( + "Defuncting connection (%s) to %s: %s", id(self), self.endpoint, exc + ) self.last_error = exc self.close() @@ -1038,9 +1206,13 @@ def try_callback(cb): try: cb(new_exc) except Exception: - log.warning("Ignoring unhandled exception while erroring requests for a " - "failed connection (%s) to host %s:", - id(self), self.endpoint, exc_info=True) + log.warning( + "Ignoring unhandled exception while erroring requests for a " + "failed connection (%s) to host %s:", + id(self), + self.endpoint, + exc_info=True, + ) # run first callback from this thread to ensure pool state before leaving cb, _, _ = requests.popitem()[1] @@ -1055,6 +1227,7 @@ def try_callback(cb): def err_all_callbacks(): for cb, _, _ in requests.values(): try_callback(cb) + if len(requests) < Connection.CALLBACK_ERR_THREAD_THRESHOLD: err_all_callbacks() else: @@ -1085,7 +1258,15 @@ def handle_pushed(self, response): except Exception: log.exception("Pushed event handler errored, ignoring:") - def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=None): + def send_msg( + self, + msg, + request_id, + cb, + encoder=ProtocolHandler.encode_message, + decoder=ProtocolHandler.decode_message, + result_metadata=None, + ): if self.is_defunct: msg = "Connection to %s is defunct" % self.endpoint if self.last_error: @@ -1102,8 +1283,13 @@ def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message, # queue the decoder function with the request # this allows us to inject custom functions per request to encode, decode messages self._requests[request_id] = (cb, decoder, result_metadata) - msg = encoder(msg, request_id, self.protocol_version, compressor=self.compressor, - allow_beta_protocol_version=self.allow_beta_protocol_version) + msg = encoder( + msg, + request_id, + self.protocol_version, + compressor=self.compressor, + allow_beta_protocol_version=self.allow_beta_protocol_version, + ) if self._is_checksumming_enabled: buffer = io.BytesIO() @@ -1130,8 +1316,8 @@ def wait_for_responses(self, *msgs, **kwargs): if self.last_error: msg += ": %s" % (self.last_error,) raise ConnectionShutdown(msg) - timeout = kwargs.get('timeout') - fail_on_error = kwargs.get('fail_on_error', True) + timeout = kwargs.get("timeout") + fail_on_error = kwargs.get("fail_on_error", True) waiter = ResponseWaiter(self, len(msgs), fail_on_error) # busy wait for sufficient space on the connection @@ -1144,9 +1330,11 @@ def wait_for_responses(self, *msgs, **kwargs): self.in_flight += available for i, request_id in enumerate(request_ids): - self.send_msg(msgs[messages_sent + i], - request_id, - partial(waiter.got_response, index=messages_sent + i)) + self.send_msg( + msgs[messages_sent + i], + request_id, + partial(waiter.got_response, index=messages_sent + i), + ) messages_sent += available if messages_sent == len(msgs): @@ -1172,8 +1360,8 @@ def register_watcher(self, event_type, callback, register_timeout=None): """ self._push_watchers[event_type].add(callback) self.wait_for_response( - RegisterMessage(event_list=[event_type]), - timeout=register_timeout) + RegisterMessage(event_list=[event_type]), timeout=register_timeout + ) def register_watchers(self, type_callback_dict, register_timeout=None): """ @@ -1183,7 +1371,8 @@ def register_watchers(self, type_callback_dict, register_timeout=None): self._push_watchers[event_type].add(callback) self.wait_for_response( RegisterMessage(event_list=type_callback_dict.keys()), - timeout=register_timeout) + timeout=register_timeout, + ) def control_conn_disposed(self): self.is_control_connection = False @@ -1196,14 +1385,19 @@ def _read_frame_header(self): if pos: version = buf[0] & PROTOCOL_VERSION_MASK if version not in ProtocolVersion.SUPPORTED_VERSIONS: - raise ProtocolError("This version of the driver does not support protocol version %d" % version) + raise ProtocolError( + "This version of the driver does not support protocol version %d" + % version + ) # this frame header struct is everything after the version byte header_size = frame_header_v3.size + 1 if pos >= header_size: flags, stream, op, body_len = frame_header_v3.unpack_from(buf, 1) if body_len < 0: raise ProtocolError("Received negative body length: %r" % body_len) - self._current_frame = _Frame(version, flags, stream, op, header_size, body_len + header_size) + self._current_frame = _Frame( + version, flags, stream, op, header_size, body_len + header_size + ) return pos @defunct_on_error @@ -1212,7 +1406,9 @@ def _process_segment_buffer(self): if readable_bytes >= self._segment_codec.header_length_with_crc: try: self._io_buffer.io_buffer.seek(0) - segment_header = self._segment_codec.decode_header(self._io_buffer.io_buffer) + segment_header = self._segment_codec.decode_header( + self._io_buffer.io_buffer + ) if readable_bytes >= segment_header.segment_length: segment = self._segment_codec.decode(self._iobuf, segment_header) @@ -1235,7 +1431,10 @@ def process_io_buffer(self): self._process_segment_buffer() self._io_buffer.reset_io_buffer() - if self._is_checksumming_enabled and not self._io_buffer.has_consumed_segment: + if ( + self._is_checksumming_enabled + and not self._io_buffer.has_consumed_segment + ): # We couldn't read an entire segment from the io buffer, so return # control to allow more bytes to be read off the wire return @@ -1246,7 +1445,10 @@ def process_io_buffer(self): pos = self._io_buffer.readable_cql_frame_bytes() if not self._current_frame or pos < self._current_frame.end_pos: - if self._is_checksumming_enabled and self._io_buffer.readable_io_bytes(): + if ( + self._is_checksumming_enabled + and self._io_buffer.readable_io_bytes() + ): # We have a multi-segments message and we need to read more # data to complete the current cql frame continue @@ -1258,7 +1460,9 @@ def process_io_buffer(self): else: frame = self._current_frame self._io_buffer.cql_frame_buffer.seek(frame.body_offset) - msg = self._io_buffer.cql_frame_buffer.read(frame.end_pos - frame.body_offset) + msg = self._io_buffer.cql_frame_buffer.read( + frame.end_pos - frame.body_offset + ) self.process_msg(frame, msg) self._io_buffer.reset_cql_frame_buffer() self._current_frame = None @@ -1297,11 +1501,23 @@ def process_msg(self, header, body): return try: - response = decoder(header.version, self.features, self.user_type_map, stream_id, - header.flags, header.opcode, body, self.decompressor, result_metadata) + response = decoder( + header.version, + self.features, + self.user_type_map, + stream_id, + header.flags, + header.opcode, + body, + self.decompressor, + result_metadata, + ) except Exception as exc: - log.exception("Error decoding response from Cassandra. " - "%s; buffer: %r", header, self._iobuf.getvalue()) + log.exception( + "Error decoding response from Cassandra. %s; buffer: %r", + header, + self._iobuf.getvalue(), + ) if callback is not None: callback(exc) self.defunct(exc) @@ -1310,10 +1526,14 @@ def process_msg(self, header, body): try: if stream_id >= 0: if isinstance(response, ProtocolException): - if 'unsupported protocol version' in response.message: + if "unsupported protocol version" in response.message: self.is_unsupported_proto_version = True else: - log.error("Closing connection %s due to protocol error: %s", self, response.summary_msg()) + log.error( + "Closing connection %s due to protocol error: %s", + self, + response.summary_msg(), + ) self.defunct(response) if callback is not None: callback(response) @@ -1347,8 +1567,14 @@ def remove_continuous_paging_session(self, stream_id): @defunct_on_error def _send_options_message(self): - log.debug("Sending initial options message for new connection (%s) to %s", id(self), self.endpoint) - self.send_msg(OptionsMessage(), self.get_request_id(), self._handle_options_response) + log.debug( + "Sending initial options message for new connection (%s) to %s", + id(self), + self.endpoint, + ) + self.send_msg( + OptionsMessage(), self.get_request_id(), self._handle_options_response + ) @defunct_on_error def _handle_options_response(self, options_response): @@ -1360,17 +1586,23 @@ def _handle_options_response(self, options_response): if isinstance(options_response, ConnectionException): raise options_response else: - log.error("Did not get expected SupportedMessage response; " - "instead, got: %s", options_response) - raise ConnectionException("Did not get expected SupportedMessage " - "response; instead, got: %s" - % (options_response,)) - - log.debug("Received options response on new connection (%s) from %s", - id(self), self.endpoint) + log.error( + "Did not get expected SupportedMessage response; instead, got: %s", + options_response, + ) + raise ConnectionException( + "Did not get expected SupportedMessage " + "response; instead, got: %s" % (options_response,) + ) + + log.debug( + "Received options response on new connection (%s) from %s", + id(self), + self.endpoint, + ) supported_cql_versions = options_response.cql_versions - remote_supported_compressions = options_response.options['COMPRESSION'] - self._product_type = options_response.options.get('PRODUCT_TYPE', [None])[0] + remote_supported_compressions = options_response.options["COMPRESSION"] + self._product_type = options_response.options.get("PRODUCT_TYPE", [None])[0] options = {} if self._application_info: @@ -1382,21 +1614,25 @@ def _handle_options_response(self, options_response): raise ProtocolError( "cql_version %r is not supported by remote (w/ native " "protocol). Supported versions: %r" - % (self.cql_version, supported_cql_versions)) + % (self.cql_version, supported_cql_versions) + ) else: self.cql_version = supported_cql_versions[0] self._compressor = None compression_type = None if self.compression: - overlap = (set(locally_supported_compressions.keys()) & - set(remote_supported_compressions)) + overlap = set(locally_supported_compressions.keys()) & set( + remote_supported_compressions + ) if len(overlap) == 0: if locally_supported_compressions: - log.error("No available compression types supported on both ends." - " locally supported: %r. remotely supported: %r", - locally_supported_compressions.keys(), - remote_supported_compressions) + log.error( + "No available compression types supported on both ends." + " locally supported: %r. remotely supported: %r", + locally_supported_compressions.keys(), + remote_supported_compressions, + ) else: compression_type = None if isinstance(self.compression, str): @@ -1404,7 +1640,8 @@ def _handle_options_response(self, options_response): if self.compression not in remote_supported_compressions: raise ProtocolError( "The requested compression type (%s) is not supported by the Cassandra server at %s" - % (self.compression, self.endpoint)) + % (self.compression, self.endpoint) + ) compression_type = self.compression else: # our locally supported compressions are ordered to prefer @@ -1416,30 +1653,42 @@ def _handle_options_response(self, options_response): # If snappy compression is selected with v5+checksumming, the connection # will fail with OTO. Only lz4 is supported - if (compression_type == 'snappy' and - ProtocolVersion.has_checksumming_support(self.protocol_version)): - log.debug("Snappy compression is not supported with protocol version %s and " - "checksumming. Consider installing lz4. Disabling compression.", self.protocol_version) + if ( + compression_type == "snappy" + and ProtocolVersion.has_checksumming_support(self.protocol_version) + ): + log.debug( + "Snappy compression is not supported with protocol version %s and " + "checksumming. Consider installing lz4. Disabling compression.", + self.protocol_version, + ) compression_type = None else: # set the decompressor here, but set the compressor only after # a successful Ready message self._compression_type = compression_type - self._compressor, self.decompressor = \ + self._compressor, self.decompressor = ( locally_supported_compressions[compression_type] + ) - self._send_startup_message(compression_type, no_compact=self.no_compact, extra_options=options) + self._send_startup_message( + compression_type, no_compact=self.no_compact, extra_options=options + ) @defunct_on_error - def _send_startup_message(self, compression=None, no_compact=False, extra_options=None): + def _send_startup_message( + self, compression=None, no_compact=False, extra_options=None + ): log.debug("Sending StartupMessage on %s", self) - opts = {'DRIVER_NAME': DRIVER_NAME, - 'DRIVER_VERSION': DRIVER_VERSION, - **extra_options} + opts = { + "DRIVER_NAME": DRIVER_NAME, + "DRIVER_VERSION": DRIVER_VERSION, + **extra_options, + } if compression: - opts['COMPRESSION'] = compression + opts["COMPRESSION"] = compression if no_compact: - opts['NO_COMPACT'] = 'true' + opts["NO_COMPACT"] = "true" sm = StartupMessage(cqlversion=self.cql_version, options=opts) self.send_msg(sm, self.get_request_id(), cb=self._handle_startup_response) log.debug("Sent StartupMessage on %s", self) @@ -1451,12 +1700,18 @@ def _handle_startup_response(self, startup_response, did_authenticate=False): if isinstance(startup_response, ReadyMessage): if self.authenticator: - log.warning("An authentication challenge was not sent, " - "this is suspicious because the driver expects " - "authentication (configured authenticator = %s)", - self.authenticator.__class__.__name__) - - log.debug("Got ReadyMessage on new connection (%s) from %s", id(self), self.endpoint) + log.warning( + "An authentication challenge was not sent, " + "this is suspicious because the driver expects " + "authentication (configured authenticator = %s)", + self.authenticator.__class__.__name__, + ) + + log.debug( + "Got ReadyMessage on new connection (%s) from %s", + id(self), + self.endpoint, + ) self._enable_compression() if ProtocolVersion.has_checksumming_support(self.protocol_version): @@ -1464,14 +1719,21 @@ def _handle_startup_response(self, startup_response, did_authenticate=False): self.connected_event.set() elif isinstance(startup_response, AuthenticateMessage): - log.debug("Got AuthenticateMessage on new connection (%s) from %s: %s", - id(self), self.endpoint, startup_response.authenticator) + log.debug( + "Got AuthenticateMessage on new connection (%s) from %s: %s", + id(self), + self.endpoint, + startup_response.authenticator, + ) if self.authenticator is None: - log.error("Failed to authenticate to %s. If you are trying to connect to a DSE cluster, " - "consider using TransitionalModePlainTextAuthProvider " - "if DSE authentication is configured with transitional mode" % (self.host,)) - raise AuthenticationFailed('Remote end requires authentication') + log.error( + "Failed to authenticate to %s. If you are trying to connect to a DSE cluster, " + "consider using TransitionalModePlainTextAuthProvider " + "if DSE authentication is configured with transitional mode" + % (self.host,) + ) + raise AuthenticationFailed("Remote end requires authentication") self._enable_compression() if ProtocolVersion.has_checksumming_support(self.protocol_version): @@ -1484,24 +1746,38 @@ def _handle_startup_response(self, startup_response, did_authenticate=False): self.send_msg(cm, self.get_request_id(), cb=callback) else: log.debug("Sending SASL-based auth response on %s", self) - self.authenticator.server_authenticator_class = startup_response.authenticator + self.authenticator.server_authenticator_class = ( + startup_response.authenticator + ) initial_response = self.authenticator.initial_response() initial_response = "" if initial_response is None else initial_response - self.send_msg(AuthResponseMessage(initial_response), self.get_request_id(), - self._handle_auth_response) + self.send_msg( + AuthResponseMessage(initial_response), + self.get_request_id(), + self._handle_auth_response, + ) elif isinstance(startup_response, ErrorMessage): - log.debug("Received ErrorMessage on new connection (%s) from %s: %s", - id(self), self.endpoint, startup_response.summary_msg()) + log.debug( + "Received ErrorMessage on new connection (%s) from %s: %s", + id(self), + self.endpoint, + startup_response.summary_msg(), + ) if did_authenticate: raise AuthenticationFailed( - "Failed to authenticate to %s: %s" % - (self.endpoint, startup_response.summary_msg())) + "Failed to authenticate to %s: %s" + % (self.endpoint, startup_response.summary_msg()) + ) else: raise ConnectionException( "Failed to initialize new connection to %s: %s" - % (self.endpoint, startup_response.summary_msg())) + % (self.endpoint, startup_response.summary_msg()) + ) elif isinstance(startup_response, ConnectionShutdown): - log.debug("Connection to %s was closed during the startup handshake", (self.endpoint)) + log.debug( + "Connection to %s was closed during the startup handshake", + (self.endpoint), + ) raise startup_response else: msg = "Unexpected response during Connection setup: %r" @@ -1525,13 +1801,21 @@ def _handle_auth_response(self, auth_response): log.debug("Responding to auth challenge on %s", self) self.send_msg(msg, self.get_request_id(), self._handle_auth_response) elif isinstance(auth_response, ErrorMessage): - log.debug("Received ErrorMessage on new connection (%s) from %s: %s", - id(self), self.endpoint, auth_response.summary_msg()) + log.debug( + "Received ErrorMessage on new connection (%s) from %s: %s", + id(self), + self.endpoint, + auth_response.summary_msg(), + ) raise AuthenticationFailed( - "Failed to authenticate to %s: %s" % - (self.endpoint, auth_response.summary_msg())) + "Failed to authenticate to %s: %s" + % (self.endpoint, auth_response.summary_msg()) + ) elif isinstance(auth_response, ConnectionShutdown): - log.debug("Connection to %s was closed during the authentication process", self.endpoint) + log.debug( + "Connection to %s was closed during the authentication process", + self.endpoint, + ) raise auth_response else: msg = "Unexpected response during Connection authentication to %s: %r" @@ -1542,8 +1826,9 @@ def set_keyspace_blocking(self, keyspace): if not keyspace or keyspace == self.keyspace: return - query = QueryMessage(query='USE "%s"' % (keyspace,), - consistency_level=ConsistencyLevel.ONE) + query = QueryMessage( + query='USE "%s"' % (keyspace,), consistency_level=ConsistencyLevel.ONE + ) try: result = self.wait_for_response(query) except InvalidRequestException as ire: @@ -1551,7 +1836,8 @@ def set_keyspace_blocking(self, keyspace): raise ire.to_exception() except Exception as exc: conn_exc = ConnectionException( - "Problem while setting keyspace: %r" % (exc,), self.endpoint) + "Problem while setting keyspace: %r" % (exc,), self.endpoint + ) self.defunct(conn_exc) raise conn_exc @@ -1559,7 +1845,8 @@ def set_keyspace_blocking(self, keyspace): self.keyspace = keyspace else: conn_exc = ConnectionException( - "Problem while setting keyspace: %r" % (result,), self.endpoint) + "Problem while setting keyspace: %r" % (result,), self.endpoint + ) self.defunct(conn_exc) raise conn_exc @@ -1596,8 +1883,9 @@ def set_keyspace_async(self, keyspace, callback): callback(self, None) return - query = QueryMessage(query='USE "%s"' % (keyspace,), - consistency_level=ConsistencyLevel.ONE) + query = QueryMessage( + query='USE "%s"' % (keyspace,), consistency_level=ConsistencyLevel.ONE + ) def process_result(result): if isinstance(result, ResultMessage): @@ -1606,8 +1894,15 @@ def process_result(result): elif isinstance(result, InvalidRequestException): callback(self, result.to_exception()) else: - callback(self, self.defunct(ConnectionException( - "Problem while setting keyspace: %r" % (result,), self.endpoint))) + callback( + self, + self.defunct( + ConnectionException( + "Problem while setting keyspace: %r" % (result,), + self.endpoint, + ) + ), + ) # We've incremented self.in_flight above, so we "have permission" to # acquire a new request id @@ -1629,12 +1924,17 @@ def __str__(self): elif self.is_closed: status = " (closed)" - return "<%s(%r) %s%s>" % (self.__class__.__name__, id(self), self.endpoint, status) + return "<%s(%r) %s%s>" % ( + self.__class__.__name__, + id(self), + self.endpoint, + status, + ) + __repr__ = __str__ class ResponseWaiter(object): - def __init__(self, connection, num_responses, fail_on_error): self.connection = connection self.pending = num_responses @@ -1647,7 +1947,7 @@ def got_response(self, response, index): with self.connection.lock: self.connection.in_flight -= 1 if isinstance(response, Exception): - if hasattr(response, 'to_exception'): + if hasattr(response, "to_exception"): response = response.to_exception() if self.fail_on_error: self.error = response @@ -1689,14 +1989,23 @@ def __init__(self, connection, owner): self._event = Event() self.connection = connection self.owner = owner - log.debug("Sending options message heartbeat on idle connection (%s) %s", - id(connection), connection.endpoint) + log.debug( + "Sending options message heartbeat on idle connection (%s) %s", + id(connection), + connection.endpoint, + ) with connection.lock: if connection.in_flight < connection.max_request_id: connection.in_flight += 1 - connection.send_msg(OptionsMessage(), connection.get_request_id(), self._options_callback) + connection.send_msg( + OptionsMessage(), + connection.get_request_id(), + self._options_callback, + ) else: - self._exception = Exception("Failed to send heartbeat because connection 'in_flight' exceeds threshold") + self._exception = Exception( + "Failed to send heartbeat because connection 'in_flight' exceeds threshold" + ) self._event.set() def wait(self, timeout): @@ -1705,23 +2014,29 @@ def wait(self, timeout): if self._exception: raise self._exception else: - raise OperationTimedOut("Connection heartbeat timeout after %s seconds" % (timeout,), self.connection.endpoint) + raise OperationTimedOut( + "Connection heartbeat timeout after %s seconds" % (timeout,), + self.connection.endpoint, + ) def _options_callback(self, response): if isinstance(response, SupportedMessage): - log.debug("Received options response on connection (%s) from %s", - id(self.connection), self.connection.endpoint) + log.debug( + "Received options response on connection (%s) from %s", + id(self.connection), + self.connection.endpoint, + ) else: if isinstance(response, ConnectionException): self._exception = response else: - self._exception = ConnectionException("Received unexpected response to OptionsMessage: %s" - % (response,)) + self._exception = ConnectionException( + "Received unexpected response to OptionsMessage: %s" % (response,) + ) self._event.set() class ConnectionHeartbeat(Thread): - def __init__(self, interval_sec, get_connection_holders, timeout): Thread.__init__(self, name="Connection heartbeat") self._interval = interval_sec @@ -1742,7 +2057,9 @@ def run(self): futures = [] failed_connections = [] try: - for connections, owner in [(o.get_connections(), o) for o in self._get_connection_holders()]: + for connections, owner in [ + (o.get_connections(), o) for o in self._get_connection_holders() + ]: for connection in connections: self._raise_if_stopped() if not (connection.is_defunct or connection.is_closed): @@ -1750,14 +2067,20 @@ def run(self): try: futures.append(HeartbeatFuture(connection, owner)) except Exception as e: - log.warning("Failed sending heartbeat message on connection (%s) to %s", - id(connection), connection.endpoint) + log.warning( + "Failed sending heartbeat message on connection (%s) to %s", + id(connection), + connection.endpoint, + ) failed_connections.append((connection, owner, e)) else: connection.reset_idle() else: - log.debug("Cannot send heartbeat message on connection (%s) to %s", - id(connection), connection.endpoint) + log.debug( + "Cannot send heartbeat message on connection (%s) to %s", + id(connection), + connection.endpoint, + ) # make sure the owner sees this defunt/closed connection owner.return_connection(connection) self._raise_if_stopped() @@ -1775,8 +2098,11 @@ def run(self): connection.in_flight -= 1 connection.reset_idle() except Exception as e: - log.warning("Heartbeat failed for connection (%s) to %s", - id(connection), connection.endpoint) + log.warning( + "Heartbeat failed for connection (%s) to %s", + id(connection), + connection.endpoint, + ) failed_connections.append((f.connection, f.owner, e)) timeout = self._timeout - (time.time() - start_time) @@ -1806,7 +2132,6 @@ def _raise_if_stopped(self): class Timer(object): - canceled = False def __init__(self, timeout, callback): @@ -1831,7 +2156,6 @@ def finish(self, time_now): class TimerManager(object): - def __init__(self): self._queue = [] self._new_timers = [] diff --git a/cassandra/cqltypes.py b/cassandra/cqltypes.py index e043a05015..d500885a69 100644 --- a/cassandra/cqltypes.py +++ b/cassandra/cqltypes.py @@ -27,7 +27,6 @@ # for example), these classes would be a good place to tack on # .from_cql_literal() and .as_cql_literal() classmethods (or whatever). -from __future__ import absolute_import # to enable import io from stdlib import ast from binascii import unhexlify import calendar diff --git a/cassandra/protocol.py b/cassandra/protocol.py index f37633a756..2276dacbaf 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import absolute_import # to enable import io from stdlib from collections import namedtuple import logging import socket @@ -22,22 +21,64 @@ from cassandra import OperationType, ProtocolVersion from cassandra import type_codes, DriverException -from cassandra import (Unavailable, WriteTimeout, RateLimitReached, ReadTimeout, - WriteFailure, ReadFailure, FunctionFailure, - AlreadyExists, InvalidRequest, Unauthorized, - UnsupportedOperation, UserFunctionDescriptor, - UserAggregateDescriptor, SchemaTargetType) -from cassandra.cqltypes import (AsciiType, BytesType, BooleanType, - CounterColumnType, DateType, DecimalType, - DoubleType, FloatType, Int32Type, - InetAddressType, IntegerType, ListType, - LongType, MapType, SetType, TimeUUIDType, - UTF8Type, VarcharType, UUIDType, UserType, - TupleType, lookup_casstype, SimpleDateType, - TimeType, ByteType, ShortType, DurationType) -from cassandra.marshal import (int32_pack, int32_unpack, uint16_pack, uint16_unpack, - uint8_pack, int8_unpack, uint64_pack, - v3_header_pack, uint32_pack, uint32_le_unpack, uint32_le_pack) +from cassandra import ( + Unavailable, + WriteTimeout, + RateLimitReached, + ReadTimeout, + WriteFailure, + ReadFailure, + FunctionFailure, + AlreadyExists, + InvalidRequest, + Unauthorized, + UnsupportedOperation, + UserFunctionDescriptor, + UserAggregateDescriptor, + SchemaTargetType, +) +from cassandra.cqltypes import ( + AsciiType, + BytesType, + BooleanType, + CounterColumnType, + DateType, + DecimalType, + DoubleType, + FloatType, + Int32Type, + InetAddressType, + IntegerType, + ListType, + LongType, + MapType, + SetType, + TimeUUIDType, + UTF8Type, + VarcharType, + UUIDType, + UserType, + TupleType, + lookup_casstype, + SimpleDateType, + TimeType, + ByteType, + ShortType, + DurationType, +) +from cassandra.marshal import ( + int32_pack, + int32_unpack, + uint16_pack, + uint16_unpack, + uint8_pack, + int8_unpack, + uint64_pack, + v3_header_pack, + uint32_pack, + uint32_le_unpack, + uint32_le_pack, +) from cassandra.policies import ColDesc from cassandra import WriteType from cassandra.cython_deps import HAVE_CYTHON, HAVE_NUMPY @@ -53,7 +94,10 @@ class NotSupportedError(Exception): class InternalError(Exception): pass -ColumnMetadata = namedtuple("ColumnMetadata", ['keyspace_name', 'table_name', 'name', 'type']) + +ColumnMetadata = namedtuple( + "ColumnMetadata", ["keyspace_name", "table_name", "name", "type"] +) HEADER_DIRECTION_TO_CLIENT = 0x80 HEADER_DIRECTION_MASK = 0x80 @@ -80,12 +124,11 @@ def get_registered_classes(): class _RegisterMessageType(type): def __init__(cls, name, bases, dct): - if not name.startswith('_'): + if not name.startswith("_"): register_class(cls) class _MessageType(object, metaclass=_RegisterMessageType): - tracing = False custom_payload = None warnings = None @@ -96,17 +139,23 @@ def update_custom_payload(self, other): self.custom_payload = {} self.custom_payload.update(other) if len(self.custom_payload) > 65535: - raise ValueError("Custom payload map exceeds max count allowed by protocol (65535)") + raise ValueError( + "Custom payload map exceeds max count allowed by protocol (65535)" + ) def __repr__(self): - return '<%s(%s)>' % (self.__class__.__name__, ', '.join('%s=%r' % i for i in _get_params(self))) + return "<%s(%s)>" % ( + self.__class__.__name__, + ", ".join("%s=%r" % i for i in _get_params(self)), + ) def _get_params(message_obj): base_attrs = dir(_MessageType) return ( - (n, a) for n, a in message_obj.__dict__.items() - if n not in base_attrs and not n.startswith('_') and not callable(a) + (n, a) + for n, a in message_obj.__dict__.items() + if n not in base_attrs and not n.startswith("_") and not callable(a) ) @@ -115,8 +164,8 @@ def _get_params(message_obj): class ErrorMessage(_MessageType, Exception): opcode = 0x00 - name = 'ERROR' - summary = 'Unknown' + name = "ERROR" + summary = "Unknown" def __init__(self, code, message, info): self.code = code @@ -135,12 +184,16 @@ def recv_body(cls, f, protocol_version, protocol_features, *args): return subcls(code=code, message=msg, info=extra_info) def summary_msg(self): - msg = 'Error from server: code=%04x [%s] message="%s"' \ - % (self.code, self.summary, self.message) + msg = 'Error from server: code=%04x [%s] message="%s"' % ( + self.code, + self.summary, + self.message, + ) return msg def __str__(self): - return '<%s>' % self.summary_msg() + return "<%s>" % self.summary_msg() + __repr__ = __str__ @staticmethod @@ -170,34 +223,34 @@ class RequestValidationException(ErrorMessageSub): class ServerError(ErrorMessageSub): - summary = 'Server error' + summary = "Server error" error_code = 0x0000 class ProtocolException(ErrorMessageSub): - summary = 'Protocol error' + summary = "Protocol error" error_code = 0x000A @property def is_beta_protocol_error(self): - return 'USE_BETA flag is unset' in str(self) + return "USE_BETA flag is unset" in str(self) class BadCredentials(ErrorMessageSub): - summary = 'Bad credentials' + summary = "Bad credentials" error_code = 0x0100 class UnavailableErrorMessage(RequestExecutionException): - summary = 'Unavailable exception' + summary = "Unavailable exception" error_code = 0x1000 @staticmethod def recv_error_info(f, protocol_version): return { - 'consistency': read_consistency_level(f), - 'required_replicas': read_int(f), - 'alive_replicas': read_int(f), + "consistency": read_consistency_level(f), + "required_replicas": read_int(f), + "alive_replicas": read_int(f), } def to_exception(self): @@ -205,17 +258,17 @@ def to_exception(self): class OverloadedErrorMessage(RequestExecutionException): - summary = 'Coordinator node overloaded' + summary = "Coordinator node overloaded" error_code = 0x1001 class IsBootstrappingErrorMessage(RequestExecutionException): - summary = 'Coordinator node is bootstrapping' + summary = "Coordinator node is bootstrapping" error_code = 0x1002 class TruncateError(RequestExecutionException): - summary = 'Error during truncate' + summary = "Error during truncate" error_code = 0x1003 @@ -226,10 +279,10 @@ class WriteTimeoutErrorMessage(RequestExecutionException): @staticmethod def recv_error_info(f, protocol_version): return { - 'consistency': read_consistency_level(f), - 'received_responses': read_int(f), - 'required_responses': read_int(f), - 'write_type': WriteType.name_to_value[read_string(f)], + "consistency": read_consistency_level(f), + "received_responses": read_int(f), + "required_responses": read_int(f), + "write_type": WriteType.name_to_value[read_string(f)], } def to_exception(self): @@ -243,10 +296,10 @@ class ReadTimeoutErrorMessage(RequestExecutionException): @staticmethod def recv_error_info(f, protocol_version): return { - 'consistency': read_consistency_level(f), - 'received_responses': read_int(f), - 'required_responses': read_int(f), - 'data_retrieved': bool(read_byte(f)), + "consistency": read_consistency_level(f), + "received_responses": read_int(f), + "required_responses": read_int(f), + "data_retrieved": bool(read_byte(f)), } def to_exception(self): @@ -273,12 +326,12 @@ def recv_error_info(f, protocol_version): data_retrieved = bool(read_byte(f)) return { - 'consistency': consistency, - 'received_responses': received_responses, - 'required_responses': required_responses, - 'failures': failures, - 'error_code_map': error_code_map, - 'data_retrieved': data_retrieved + "consistency": consistency, + "received_responses": received_responses, + "required_responses": required_responses, + "failures": failures, + "error_code_map": error_code_map, + "data_retrieved": data_retrieved, } def to_exception(self): @@ -292,9 +345,9 @@ class FunctionFailureMessage(RequestExecutionException): @staticmethod def recv_error_info(f, protocol_version): return { - 'keyspace': read_string(f), - 'function': read_string(f), - 'arg_types': [read_string(f) for _ in range(read_short(f))], + "keyspace": read_string(f), + "function": read_string(f), + "arg_types": [read_string(f) for _ in range(read_short(f))], } def to_exception(self): @@ -321,12 +374,12 @@ def recv_error_info(f, protocol_version): write_type = WriteType.name_to_value[read_string(f)] return { - 'consistency': consistency, - 'received_responses': received_responses, - 'required_responses': required_responses, - 'failures': failures, - 'error_code_map': error_code_map, - 'write_type': write_type + "consistency": consistency, + "received_responses": received_responses, + "required_responses": required_responses, + "failures": failures, + "error_code_map": error_code_map, + "write_type": write_type, } def to_exception(self): @@ -334,17 +387,17 @@ def to_exception(self): class CDCWriteException(RequestExecutionException): - summary = 'Failed to execute write due to CDC space exhaustion.' + summary = "Failed to execute write due to CDC space exhaustion." error_code = 0x1600 class SyntaxException(RequestValidationException): - summary = 'Syntax error in CQL query' + summary = "Syntax error in CQL query" error_code = 0x2000 class UnauthorizedErrorMessage(RequestValidationException): - summary = 'Unauthorized' + summary = "Unauthorized" error_code = 0x2100 def to_exception(self): @@ -352,7 +405,7 @@ def to_exception(self): class InvalidRequestException(RequestValidationException): - summary = 'Invalid query' + summary = "Invalid query" error_code = 0x2200 def to_exception(self): @@ -360,12 +413,12 @@ def to_exception(self): class ConfigurationException(RequestValidationException): - summary = 'Query invalid because of configuration issue' + summary = "Query invalid because of configuration issue" error_code = 0x2300 class PreparedQueryNotFound(RequestValidationException): - summary = 'Matching prepared statement not found on this node' + summary = "Matching prepared statement not found on this node" error_code = 0x2500 @staticmethod @@ -375,47 +428,45 @@ def recv_error_info(f, protocol_version): class AlreadyExistsException(ConfigurationException): - summary = 'Item already exists' + summary = "Item already exists" error_code = 0x2400 @staticmethod def recv_error_info(f, protocol_version): return { - 'keyspace': read_string(f), - 'table': read_string(f), + "keyspace": read_string(f), + "table": read_string(f), } def to_exception(self): return AlreadyExists(**self.info) + class RateLimitReachedException(ConfigurationException): - summary= 'Rate limit was exceeded for a partition affected by the request' + summary = "Rate limit was exceeded for a partition affected by the request" error_code = 0x4321 @staticmethod def recv_error_info(f, protocol_version): return { - 'op_type': OperationType(read_byte(f)), - 'rejected_by_coordinator': read_byte(f) != 0 + "op_type": OperationType(read_byte(f)), + "rejected_by_coordinator": read_byte(f) != 0, } def to_exception(self): return RateLimitReached(**self.info) + class ClientWriteError(RequestExecutionException): - summary = 'Client write failure.' + summary = "Client write failure." error_code = 0x8000 class StartupMessage(_MessageType): opcode = 0x01 - name = 'STARTUP' + name = "STARTUP" - KNOWN_OPTION_KEYS = set(( - 'CQL_VERSION', - 'COMPRESSION', - 'NO_COMPACT' - )) + KNOWN_OPTION_KEYS = set(("CQL_VERSION", "COMPRESSION", "NO_COMPACT")) def __init__(self, cqlversion, options): self.cqlversion = cqlversion @@ -423,13 +474,13 @@ def __init__(self, cqlversion, options): def send_body(self, f, protocol_version): optmap = self.options.copy() - optmap['CQL_VERSION'] = self.cqlversion + optmap["CQL_VERSION"] = self.cqlversion write_stringmap(f, optmap) class ReadyMessage(_MessageType): opcode = 0x02 - name = 'READY' + name = "READY" @classmethod def recv_body(cls, *args): @@ -438,7 +489,7 @@ def recv_body(cls, *args): class AuthenticateMessage(_MessageType): opcode = 0x03 - name = 'AUTHENTICATE' + name = "AUTHENTICATE" def __init__(self, authenticator): self.authenticator = authenticator @@ -451,7 +502,7 @@ def recv_body(cls, f, *args): class CredentialsMessage(_MessageType): opcode = 0x04 - name = 'CREDENTIALS' + name = "CREDENTIALS" def __init__(self, creds): self.creds = creds @@ -461,7 +512,8 @@ def send_body(self, f, protocol_version): raise UnsupportedOperation( "Credentials-based authentication is not supported with " "protocol version 2 or higher. Use the SASL authentication " - "mechanism instead.") + "mechanism instead." + ) write_short(f, len(self.creds)) for credkey, credval in self.creds.items(): write_string(f, credkey) @@ -470,7 +522,7 @@ def send_body(self, f, protocol_version): class AuthChallengeMessage(_MessageType): opcode = 0x0E - name = 'AUTH_CHALLENGE' + name = "AUTH_CHALLENGE" def __init__(self, challenge): self.challenge = challenge @@ -482,7 +534,7 @@ def recv_body(cls, f, *args): class AuthResponseMessage(_MessageType): opcode = 0x0F - name = 'AUTH_RESPONSE' + name = "AUTH_RESPONSE" def __init__(self, response): self.response = response @@ -493,7 +545,7 @@ def send_body(self, f, protocol_version): class AuthSuccessMessage(_MessageType): opcode = 0x10 - name = 'AUTH_SUCCESS' + name = "AUTH_SUCCESS" def __init__(self, token): self.token = token @@ -505,7 +557,7 @@ def recv_body(cls, f, *args): class OptionsMessage(_MessageType): opcode = 0x05 - name = 'OPTIONS' + name = "OPTIONS" def send_body(self, f, protocol_version): pass @@ -513,7 +565,7 @@ def send_body(self, f, protocol_version): class SupportedMessage(_MessageType): opcode = 0x06 - name = 'SUPPORTED' + name = "SUPPORTED" def __init__(self, cql_versions, options): self.cql_versions = cql_versions @@ -522,7 +574,7 @@ def __init__(self, cql_versions, options): @classmethod def recv_body(cls, f, *args): options = read_stringmultimap(f) - cql_versions = options.pop('CQL_VERSION') + cql_versions = options.pop("CQL_VERSION") return cls(cql_versions=cql_versions, options=options) @@ -541,11 +593,18 @@ def recv_body(cls, f, *args): class _QueryMessage(_MessageType): - - def __init__(self, query_params, consistency_level, - serial_consistency_level=None, fetch_size=None, - paging_state=None, timestamp=None, skip_meta=False, - continuous_paging_options=None, keyspace=None): + def __init__( + self, + query_params, + consistency_level, + serial_consistency_level=None, + fetch_size=None, + paging_state=None, + timestamp=None, + skip_meta=False, + continuous_paging_options=None, + keyspace=None, + ): self.query_params = query_params self.consistency_level = consistency_level self.serial_consistency_level = serial_consistency_level @@ -579,7 +638,8 @@ def _write_query_params(self, f, protocol_version): else: raise UnsupportedOperation( "Keyspaces may only be set on queries with protocol version " - "5 or DSE_V2 or higher. Consider setting Cluster.protocol_version.") + "5 or DSE_V2 or higher. Consider setting Cluster.protocol_version." + ) if ProtocolVersion.uses_int_query_flags(protocol_version): write_uint(f, flags) @@ -608,13 +668,31 @@ def _write_paging_options(self, f, paging_options, protocol_version): class QueryMessage(_QueryMessage): opcode = 0x07 - name = 'QUERY' - - def __init__(self, query, consistency_level, serial_consistency_level=None, - fetch_size=None, paging_state=None, timestamp=None, continuous_paging_options=None, keyspace=None): + name = "QUERY" + + def __init__( + self, + query, + consistency_level, + serial_consistency_level=None, + fetch_size=None, + paging_state=None, + timestamp=None, + continuous_paging_options=None, + keyspace=None, + ): self.query = query - super(QueryMessage, self).__init__(None, consistency_level, serial_consistency_level, fetch_size, - paging_state, timestamp, False, continuous_paging_options, keyspace) + super(QueryMessage, self).__init__( + None, + consistency_level, + serial_consistency_level, + fetch_size, + paging_state, + timestamp, + False, + continuous_paging_options, + keyspace, + ) def send_body(self, f, protocol_version): write_longstring(f, self.query) @@ -623,16 +701,33 @@ def send_body(self, f, protocol_version): class ExecuteMessage(_QueryMessage): opcode = 0x0A - name = 'EXECUTE' - - def __init__(self, query_id, query_params, consistency_level, - serial_consistency_level=None, fetch_size=None, - paging_state=None, timestamp=None, skip_meta=False, - continuous_paging_options=None, result_metadata_id=None): + name = "EXECUTE" + + def __init__( + self, + query_id, + query_params, + consistency_level, + serial_consistency_level=None, + fetch_size=None, + paging_state=None, + timestamp=None, + skip_meta=False, + continuous_paging_options=None, + result_metadata_id=None, + ): self.query_id = query_id self.result_metadata_id = result_metadata_id - super(ExecuteMessage, self).__init__(query_params, consistency_level, serial_consistency_level, fetch_size, - paging_state, timestamp, skip_meta, continuous_paging_options) + super(ExecuteMessage, self).__init__( + query_params, + consistency_level, + serial_consistency_level, + fetch_size, + paging_state, + timestamp, + skip_meta, + continuous_paging_options, + ) def _write_query_params(self, f, protocol_version): super(ExecuteMessage, self)._write_query_params(f, protocol_version) @@ -655,14 +750,18 @@ def send_body(self, f, protocol_version): class ResultMessage(_MessageType): opcode = 0x08 - name = 'RESULT' + name = "RESULT" kind = None results = None paging_state = None # Names match type name in module scope. Most are imported from cassandra.cqltypes (except CUSTOM_TYPE) - type_codes = _cqltypes_by_code = dict((v, globals()[k]) for k, v in type_codes.__dict__.items() if not k.startswith('_')) + type_codes = _cqltypes_by_code = dict( + (v, globals()[k]) + for k, v in type_codes.__dict__.items() + if not k.startswith("_") + ) _FLAGS_GLOBAL_TABLES_SPEC = 0x0001 _HAS_MORE_PAGES_FLAG = 0x0002 @@ -691,28 +790,66 @@ class ResultMessage(_MessageType): def __init__(self, kind): self.kind = kind - def recv(self, f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy): + def recv( + self, + f, + protocol_version, + protocol_features, + user_type_map, + result_metadata, + column_encryption_policy, + ): if self.kind == RESULT_KIND_VOID: return elif self.kind == RESULT_KIND_ROWS: - self.recv_results_rows(f, protocol_version, user_type_map, result_metadata, column_encryption_policy) + self.recv_results_rows( + f, + protocol_version, + user_type_map, + result_metadata, + column_encryption_policy, + ) elif self.kind == RESULT_KIND_SET_KEYSPACE: self.new_keyspace = read_string(f) elif self.kind == RESULT_KIND_PREPARED: - self.recv_results_prepared(f, protocol_version, protocol_features, user_type_map) + self.recv_results_prepared( + f, protocol_version, protocol_features, user_type_map + ) elif self.kind == RESULT_KIND_SCHEMA_CHANGE: self.recv_results_schema_change(f, protocol_version) else: raise DriverException("Unknown RESULT kind: %d" % self.kind) @classmethod - def recv_body(cls, f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy): + def recv_body( + cls, + f, + protocol_version, + protocol_features, + user_type_map, + result_metadata, + column_encryption_policy, + ): kind = read_int(f) msg = cls(kind) - msg.recv(f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy) + msg.recv( + f, + protocol_version, + protocol_features, + user_type_map, + result_metadata, + column_encryption_policy, + ) return msg - def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata, column_encryption_policy): + def recv_results_rows( + self, + f, + protocol_version, + user_type_map, + result_metadata, + column_encryption_policy, + ): self.recv_results_metadata(f, user_type_map) column_metadata = self.column_metadata or result_metadata rowcount = read_int(f) @@ -722,13 +859,23 @@ def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata, col_descs = [ColDesc(md[0], md[1], md[2]) for md in column_metadata] def decode_val(val, col_md, col_desc): - uses_ce = column_encryption_policy and column_encryption_policy.contains_column(col_desc) - col_type = column_encryption_policy.column_type(col_desc) if uses_ce else col_md[3] - raw_bytes = column_encryption_policy.decrypt(col_desc, val) if uses_ce else val + uses_ce = ( + column_encryption_policy + and column_encryption_policy.contains_column(col_desc) + ) + col_type = ( + column_encryption_policy.column_type(col_desc) if uses_ce else col_md[3] + ) + raw_bytes = ( + column_encryption_policy.decrypt(col_desc, val) if uses_ce else val + ) return col_type.from_binary(raw_bytes, protocol_version) def decode_row(row): - return tuple(decode_val(val, col_md, col_desc) for val, col_md, col_desc in zip(row, column_metadata, col_descs)) + return tuple( + decode_val(val, col_md, col_desc) + for val, col_md, col_desc in zip(row, column_metadata, col_descs) + ) try: self.parsed_rows = [decode_row(row) for row in rows] @@ -738,17 +885,22 @@ def decode_row(row): try: decode_val(val, col_md, col_desc) except Exception as e: - raise DriverException('Failed decoding result column "%s" of type %s: %s' % (col_md[2], - col_md[3].cql_parameterized_type(), - str(e))) - - def recv_results_prepared(self, f, protocol_version, protocol_features, user_type_map): + raise DriverException( + 'Failed decoding result column "%s" of type %s: %s' + % (col_md[2], col_md[3].cql_parameterized_type(), str(e)) + ) + + def recv_results_prepared( + self, f, protocol_version, protocol_features, user_type_map + ): self.query_id = read_binary_string(f) if ProtocolVersion.uses_prepared_metadata(protocol_version): self.result_metadata_id = read_binary_string(f) else: self.result_metadata_id = None - self.recv_prepared_metadata(f, protocol_version, protocol_features, user_type_map) + self.recv_prepared_metadata( + f, protocol_version, protocol_features, user_type_map + ) def recv_results_metadata(self, f, user_type_map): flags = read_int(f) @@ -786,9 +938,15 @@ def recv_results_metadata(self, f, user_type_map): self.column_metadata = column_metadata - def recv_prepared_metadata(self, f, protocol_version, protocol_features, user_type_map): + def recv_prepared_metadata( + self, f, protocol_version, protocol_features, user_type_map + ): flags = read_int(f) - self.is_lwt = protocol_features.lwt_info.get_lwt_flag(flags) if protocol_features.lwt_info is not None else False + self.is_lwt = ( + protocol_features.lwt_info.get_lwt_flag(flags) + if protocol_features.lwt_info is not None + else False + ) colcount = read_int(f) pk_indexes = None if protocol_version >= 4: @@ -825,8 +983,10 @@ def read_type(cls, f, user_type_map): try: typeclass = cls.type_codes[optid] except KeyError: - raise NotSupportedError("Unknown data type code 0x%04x. Have to skip" - " entire result set." % (optid,)) + raise NotSupportedError( + "Unknown data type code 0x%04x. Have to skip" + " entire result set." % (optid,) + ) if typeclass in (ListType, SetType): subtype = cls.read_type(f, user_type_map) typeclass = typeclass.apply_parameters((subtype,)) @@ -842,8 +1002,12 @@ def read_type(cls, f, user_type_map): ks = read_string(f) udt_name = read_string(f) num_fields = read_short(f) - names, types = zip(*((read_string(f), cls.read_type(f, user_type_map)) - for _ in range(num_fields))) + names, types = zip( + *( + (read_string(f), cls.read_type(f, user_type_map)) + for _ in range(num_fields) + ) + ) specialized_type = typeclass.make_udt_class(ks, udt_name, names, types) specialized_type.mapped_class = user_type_map.get(ks, {}).get(udt_name) typeclass = specialized_type @@ -860,7 +1024,7 @@ def recv_row(f, colcount): class PrepareMessage(_MessageType): opcode = 0x09 - name = 'PREPARE' + name = "PREPARE" def __init__(self, query, keyspace=None): self.query = query @@ -877,7 +1041,8 @@ def send_body(self, f, protocol_version): else: raise UnsupportedOperation( "Keyspaces may only be set on queries with protocol version " - "5 or DSE_V2 or higher. Consider setting Cluster.protocol_version.") + "5 or DSE_V2 or higher. Consider setting Cluster.protocol_version." + ) if ProtocolVersion.uses_prepare_flags(protocol_version): write_uint(f, flags) @@ -889,7 +1054,8 @@ def send_body(self, f, protocol_version): "protocol version {pv}, which doesn't support flags" "in prepared statements." "Consider setting Cluster.protocol_version to 5 or DSE_V2." - "".format(flags=flags, pv=protocol_version)) + "".format(flags=flags, pv=protocol_version) + ) if ProtocolVersion.uses_keyspace_flag(protocol_version): if self.keyspace: @@ -898,11 +1064,17 @@ def send_body(self, f, protocol_version): class BatchMessage(_MessageType): opcode = 0x0D - name = 'BATCH' - - def __init__(self, batch_type, queries, consistency_level, - serial_consistency_level=None, timestamp=None, - keyspace=None): + name = "BATCH" + + def __init__( + self, + batch_type, + queries, + consistency_level, + serial_consistency_level=None, + timestamp=None, + keyspace=None, + ): self.batch_type = batch_type self.queries = queries self.consistency_level = consistency_level @@ -937,7 +1109,8 @@ def send_body(self, f, protocol_version): else: raise UnsupportedOperation( "Keyspaces may only be set on queries with protocol version " - "5 or higher. Consider setting Cluster.protocol_version to 5.") + "5 or higher. Consider setting Cluster.protocol_version to 5." + ) if ProtocolVersion.uses_int_query_flags(protocol_version): write_int(f, flags) else: @@ -953,17 +1126,14 @@ def send_body(self, f, protocol_version): write_string(f, self.keyspace) -known_event_types = frozenset(( - 'TOPOLOGY_CHANGE', - 'STATUS_CHANGE', - 'SCHEMA_CHANGE', - 'CLIENT_ROUTES_CHANGE' -)) +known_event_types = frozenset( + ("TOPOLOGY_CHANGE", "STATUS_CHANGE", "SCHEMA_CHANGE", "CLIENT_ROUTES_CHANGE") +) class RegisterMessage(_MessageType): opcode = 0x0B - name = 'REGISTER' + name = "REGISTER" def __init__(self, event_list): self.event_list = event_list @@ -974,7 +1144,7 @@ def send_body(self, f, protocol_version): class EventMessage(_MessageType): opcode = 0x0C - name = 'EVENT' + name = "EVENT" def __init__(self, event_type, event_args): self.event_type = event_type @@ -984,9 +1154,11 @@ def __init__(self, event_type, event_args): def recv_body(cls, f, protocol_version, *args): event_type = read_string(f).upper() if event_type in known_event_types: - read_method = getattr(cls, 'recv_' + event_type.lower()) - return cls(event_type=event_type, event_args=read_method(f, protocol_version)) - raise NotSupportedError('Unknown event type %r' % event_type) + read_method = getattr(cls, "recv_" + event_type.lower()) + return cls( + event_type=event_type, event_args=read_method(f, protocol_version) + ) + raise NotSupportedError("Unknown event type %r" % event_type) @classmethod def recv_client_routes_change(cls, f, protocol_version): @@ -994,7 +1166,9 @@ def recv_client_routes_change(cls, f, protocol_version): change_type = read_string(f) connection_ids = read_stringlist(f) host_ids = read_stringlist(f) - return dict(change_type=change_type, connection_ids=connection_ids, host_ids=host_ids) + return dict( + change_type=change_type, connection_ids=connection_ids, host_ids=host_ids + ) @classmethod def recv_topology_change(cls, f, protocol_version): @@ -1016,26 +1190,33 @@ def recv_schema_change(cls, f, protocol_version): change_type = read_string(f) target = read_string(f) keyspace = read_string(f) - event = {'target_type': target, 'change_type': change_type, 'keyspace': keyspace} + event = { + "target_type": target, + "change_type": change_type, + "keyspace": keyspace, + } if target != SchemaTargetType.KEYSPACE: target_name = read_string(f) if target == SchemaTargetType.FUNCTION: - event['function'] = UserFunctionDescriptor(target_name, [read_string(f) for _ in range(read_short(f))]) + event["function"] = UserFunctionDescriptor( + target_name, [read_string(f) for _ in range(read_short(f))] + ) elif target == SchemaTargetType.AGGREGATE: - event['aggregate'] = UserAggregateDescriptor(target_name, [read_string(f) for _ in range(read_short(f))]) + event["aggregate"] = UserAggregateDescriptor( + target_name, [read_string(f) for _ in range(read_short(f))] + ) else: event[target.lower()] = target_name return event class ReviseRequestMessage(_MessageType): - class RevisionType(object): PAGING_CANCEL = 1 PAGING_BACKPRESSURE = 2 opcode = 0xFF - name = 'REVISE_REQUEST' + name = "REVISE_REQUEST" def __init__(self, op_type, op_id, next_pages=0): self.op_type = op_type @@ -1047,10 +1228,13 @@ def send_body(self, f, protocol_version): write_int(f, self.op_id) if self.op_type == ReviseRequestMessage.RevisionType.PAGING_BACKPRESSURE: if self.next_pages <= 0: - raise UnsupportedOperation("Continuous paging backpressure requires next_pages > 0") + raise UnsupportedOperation( + "Continuous paging backpressure requires next_pages > 0" + ) else: raise UnsupportedOperation( - "Continuous paging backpressure is not supported.") + "Continuous paging backpressure is not supported." + ) class _ProtocolHandler(object): @@ -1075,7 +1259,9 @@ class _ProtocolHandler(object): """Instance of :class:`cassandra.policies.ColumnEncryptionPolicy` in use by this handler""" @classmethod - def encode_message(cls, msg, stream_id, protocol_version, compressor, allow_beta_protocol_version): + def encode_message( + cls, msg, stream_id, protocol_version, compressor, allow_beta_protocol_version + ): """ Encodes a message using the specified frame parameters, and compressor @@ -1087,7 +1273,9 @@ def encode_message(cls, msg, stream_id, protocol_version, compressor, allow_beta flags = 0 if msg.custom_payload: if protocol_version < 4: - raise UnsupportedOperation("Custom key/value payloads can only be used with protocol version 4 or higher") + raise UnsupportedOperation( + "Custom key/value payloads can only be used with protocol version 4 or higher" + ) flags |= CUSTOM_PAYLOAD_FLAG if msg.tracing: @@ -1100,7 +1288,9 @@ def encode_message(cls, msg, stream_id, protocol_version, compressor, allow_beta buff.seek(9) # With checksumming, the compression is done at the segment frame encoding - if (compressor and not ProtocolVersion.has_checksumming_support(protocol_version)): + if compressor and not ProtocolVersion.has_checksumming_support( + protocol_version + ): body = io.BytesIO() if msg.custom_payload: write_bytesmap(body, msg.custom_payload) @@ -1133,8 +1323,18 @@ def _write_header(f, version, flags, stream_id, opcode, length): write_int(f, length) @classmethod - def decode_message(cls, protocol_version, protocol_features, user_type_map, stream_id, flags, opcode, body, - decompressor, result_metadata): + def decode_message( + cls, + protocol_version, + protocol_features, + user_type_map, + stream_id, + flags, + opcode, + body, + decompressor, + result_metadata, + ): """ Decodes a native protocol message body @@ -1147,8 +1347,10 @@ def decode_message(cls, protocol_version, protocol_features, user_type_map, stre :param decompressor: optional decompression function to inflate the body :return: a message decoded from the body and frame attributes """ - if (not ProtocolVersion.has_checksumming_support(protocol_version) and - flags & COMPRESSED_FLAG): + if ( + not ProtocolVersion.has_checksumming_support(protocol_version) + and flags & COMPRESSED_FLAG + ): if decompressor is None: raise RuntimeError("No de-compressor available for compressed frame!") body = decompressor(body) @@ -1179,7 +1381,14 @@ def decode_message(cls, protocol_version, protocol_features, user_type_map, stre log.warning("Unknown protocol flags set: %02x. May cause problems.", flags) msg_class = cls.message_types_by_opcode[opcode] - msg = msg_class.recv_body(body, protocol_version, protocol_features, user_type_map, result_metadata, cls.column_encryption_policy) + msg = msg_class.recv_body( + body, + protocol_version, + protocol_features, + user_type_map, + result_metadata, + cls.column_encryption_policy, + ) msg.stream_id = stream_id msg.trace_id = trace_id msg.custom_payload = custom_payload @@ -1217,6 +1426,7 @@ class FastResultMessage(ResultMessage): Cython version of Result Message that has a faster implementation of recv_results_row. """ + # type_codes = ResultMessage.type_codes.copy() code_to_type = dict((v, k) for k, v in ResultMessage.type_codes.items()) recv_results_rows = make_recv_results_rows(colparser) @@ -1237,6 +1447,7 @@ class CythonProtocolHandler(_ProtocolHandler): if HAVE_CYTHON: from cassandra.obj_parser import ListParser, LazyParser + ProtocolHandler = cython_protocol_handler(ListParser()) LazyProtocolHandler = cython_protocol_handler(LazyParser()) else: @@ -1247,6 +1458,7 @@ class CythonProtocolHandler(_ProtocolHandler): if HAVE_CYTHON and HAVE_NUMPY: from cassandra.numpy_parser import NumpyParser + NumpyProtocolHandler = cython_protocol_handler(NumpyParser()) else: NumpyProtocolHandler = None @@ -1322,7 +1534,7 @@ def write_consistency_level(f, cl): def read_string(f): size = read_short(f) contents = f.read(size) - return contents.decode('utf8') + return contents.decode("utf8") def read_binary_string(f): @@ -1333,7 +1545,7 @@ def read_binary_string(f): def write_string(f, s): if isinstance(s, str): - s = s.encode('utf8') + s = s.encode("utf8") write_short(f, len(s)) f.write(s) @@ -1345,12 +1557,12 @@ def read_binary_longstring(f): def read_longstring(f): - return read_binary_longstring(f).decode('utf8') + return read_binary_longstring(f).decode("utf8") def write_longstring(f, s): if isinstance(s, str): - s = s.encode('utf8') + s = s.encode("utf8") write_int(f, len(s)) f.write(s) @@ -1460,7 +1672,7 @@ def read_inet(f): def write_inet(f, addrtuple): addr, port = addrtuple - if ':' in addr: + if ":" in addr: addrfam = socket.AF_INET6 else: addrfam = socket.AF_INET diff --git a/tests/integration/cqlengine/query/test_queryset.py b/tests/integration/cqlengine/query/test_queryset.py index 34b4ab5964..91d939e872 100644 --- a/tests/integration/cqlengine/query/test_queryset.py +++ b/tests/integration/cqlengine/query/test_queryset.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import absolute_import import unittest @@ -38,8 +37,15 @@ from cassandra.cqlengine import operators from cassandra.util import uuid_from_time from cassandra.cqlengine.connection import get_session -from tests.integration import PROTOCOL_VERSION, CASSANDRA_VERSION, greaterthancass20, greaterthancass21, \ - greaterthanorequalcass30, TestCluster, requires_collection_indexes +from tests.integration import ( + PROTOCOL_VERSION, + CASSANDRA_VERSION, + greaterthancass20, + greaterthancass21, + greaterthanorequalcass30, + TestCluster, + requires_collection_indexes, +) from tests.integration.cqlengine import execute_count, DEFAULT_KEYSPACE import pytest @@ -56,7 +62,7 @@ def utcoffset(self, dt): return self._offset def tzname(self, dt): - return 'TzOffset: {}'.format(self._offset.hours) + return "TzOffset: {}".format(self._offset.hours) def dst(self, dt): return timedelta(0) @@ -73,7 +79,6 @@ class TestModel(Model): class IndexedTestModel(Model): - test_id = columns.Integer(primary_key=True) attempt_id = columns.Integer(index=True) description = columns.Text() @@ -82,7 +87,6 @@ class IndexedTestModel(Model): class CustomIndexedTestModel(Model): - test_id = columns.Integer(primary_key=True) description = columns.Text(custom_index=True) indexed = columns.Text(index=True) @@ -90,7 +94,6 @@ class CustomIndexedTestModel(Model): class IndexedCollectionsTestModel(Model): - test_id = columns.Integer(primary_key=True) attempt_id = columns.Integer(index=True) description = columns.Text() @@ -114,7 +117,6 @@ class TestMultiClusteringModel(Model): class TestQuerySetOperation(BaseCassEngTestCase): - def test_query_filter_parsing(self): """ Tests the queryset filter method parses it's kwargs properly @@ -137,7 +139,7 @@ def test_query_filter_parsing(self): assert op.value == 1 def test_query_expression_parsing(self): - """ Tests that query experessions are evaluated properly """ + """Tests that query experessions are evaluated properly""" query1 = TestModel.filter(TestModel.test_id == 5) assert len(query1._where) == 1 @@ -221,10 +223,10 @@ def test_queryset_with_distinct(self): query1 = TestModel.objects.distinct() assert len(query1._distinct_fields) == 1 - query2 = TestModel.objects.distinct(['test_id']) + query2 = TestModel.objects.distinct(["test_id"]) assert len(query2._distinct_fields) == 1 - query3 = TestModel.objects.distinct(['test_id', 'attempt_id']) + query3 = TestModel.objects.distinct(["test_id", "attempt_id"]) assert len(query3._distinct_fields) == 2 def test_defining_only_fields(self): @@ -238,37 +240,39 @@ def test_defining_only_fields(self): @test_category object_mapper """ # simple only definition - q = TestModel.objects.only(['attempt_id', 'description']) - assert q._select_fields() == ['attempt_id', 'description'] + q = TestModel.objects.only(["attempt_id", "description"]) + assert q._select_fields() == ["attempt_id", "description"] with pytest.raises(query.QueryException): - TestModel.objects.only(['nonexistent_field']) + TestModel.objects.only(["nonexistent_field"]) # Cannot define more than once only fields with pytest.raises(query.QueryException): - TestModel.objects.only(['description']).only(['attempt_id']) + TestModel.objects.only(["description"]).only(["attempt_id"]) # only with defer fields - q = TestModel.objects.only(['attempt_id', 'description']) - q = q.defer(['description']) - assert q._select_fields() == ['attempt_id'] + q = TestModel.objects.only(["attempt_id", "description"]) + q = q.defer(["description"]) + assert q._select_fields() == ["attempt_id"] # Eliminate all results confirm exception is thrown - q = TestModel.objects.only(['description']) - q = q.defer(['description']) + q = TestModel.objects.only(["description"]) + q = q.defer(["description"]) with pytest.raises(query.QueryException): q._select_fields() - q = TestModel.objects.filter(test_id=0).only(['test_id', 'attempt_id', 'description']) - assert q._select_fields() == ['attempt_id', 'description'] + q = TestModel.objects.filter(test_id=0).only( + ["test_id", "attempt_id", "description"] + ) + assert q._select_fields() == ["attempt_id", "description"] # no fields to select with pytest.raises(query.QueryException): - q = TestModel.objects.only(['test_id']).defer(['test_id']) + q = TestModel.objects.only(["test_id"]).defer(["test_id"]) q._select_fields() with pytest.raises(query.QueryException): - q = TestModel.objects.filter(test_id=0).only(['test_id']) + q = TestModel.objects.filter(test_id=0).only(["test_id"]) q._select_fields() def test_defining_defer_fields(self): @@ -284,39 +288,45 @@ def test_defining_defer_fields(self): """ # simple defer definition - q = TestModel.objects.defer(['attempt_id', 'description']) - assert q._select_fields() == ['test_id', 'expected_result', 'test_result'] + q = TestModel.objects.defer(["attempt_id", "description"]) + assert q._select_fields() == ["test_id", "expected_result", "test_result"] with pytest.raises(query.QueryException): - TestModel.objects.defer(['nonexistent_field']) + TestModel.objects.defer(["nonexistent_field"]) # defer more than one - q = TestModel.objects.defer(['attempt_id', 'description']) - q = q.defer(['expected_result']) - assert q._select_fields() == ['test_id', 'test_result'] + q = TestModel.objects.defer(["attempt_id", "description"]) + q = q.defer(["expected_result"]) + assert q._select_fields() == ["test_id", "test_result"] # defer with only - q = TestModel.objects.defer(['description', 'attempt_id']) - q = q.only(['description', 'test_id']) - assert q._select_fields() == ['test_id'] + q = TestModel.objects.defer(["description", "attempt_id"]) + q = q.only(["description", "test_id"]) + assert q._select_fields() == ["test_id"] # Eliminate all results confirm exception is thrown - q = TestModel.objects.defer(['description', 'attempt_id']) - q = q.only(['description']) + q = TestModel.objects.defer(["description", "attempt_id"]) + q = q.only(["description"]) with pytest.raises(query.QueryException): q._select_fields() # implicit defer q = TestModel.objects.filter(test_id=0) - assert q._select_fields() == ['attempt_id', 'description', 'expected_result', 'test_result'] + assert q._select_fields() == [ + "attempt_id", + "description", + "expected_result", + "test_result", + ] # when all fields are defered, it fallbacks select the partition keys - q = TestModel.objects.defer(['test_id', 'attempt_id', 'description', 'expected_result', 'test_result']) - assert q._select_fields() == ['test_id'] + q = TestModel.objects.defer( + ["test_id", "attempt_id", "description", "expected_result", "test_result"] + ) + assert q._select_fields() == ["test_id"] class BaseQuerySetUsage(BaseCassEngTestCase): - @classmethod def setUpClass(cls): super(BaseQuerySetUsage, cls).setUpClass() @@ -329,55 +339,224 @@ def setUpClass(cls): sync_table(CustomIndexedTestModel) sync_table(TestMultiClusteringModel) - TestModel.objects.create(test_id=0, attempt_id=0, description='try1', expected_result=5, test_result=30) - TestModel.objects.create(test_id=0, attempt_id=1, description='try2', expected_result=10, test_result=30) - TestModel.objects.create(test_id=0, attempt_id=2, description='try3', expected_result=15, test_result=30) - TestModel.objects.create(test_id=0, attempt_id=3, description='try4', expected_result=20, test_result=25) - - TestModel.objects.create(test_id=1, attempt_id=0, description='try5', expected_result=5, test_result=25) - TestModel.objects.create(test_id=1, attempt_id=1, description='try6', expected_result=10, test_result=25) - TestModel.objects.create(test_id=1, attempt_id=2, description='try7', expected_result=15, test_result=25) - TestModel.objects.create(test_id=1, attempt_id=3, description='try8', expected_result=20, test_result=20) - - TestModel.objects.create(test_id=2, attempt_id=0, description='try9', expected_result=50, test_result=40) - TestModel.objects.create(test_id=2, attempt_id=1, description='try10', expected_result=60, test_result=40) - TestModel.objects.create(test_id=2, attempt_id=2, description='try11', expected_result=70, test_result=45) - TestModel.objects.create(test_id=2, attempt_id=3, description='try12', expected_result=75, test_result=45) - - IndexedTestModel.objects.create(test_id=0, attempt_id=0, description='try1', expected_result=5, test_result=30) - IndexedTestModel.objects.create(test_id=1, attempt_id=1, description='try2', expected_result=10, test_result=30) - IndexedTestModel.objects.create(test_id=2, attempt_id=2, description='try3', expected_result=15, test_result=30) - IndexedTestModel.objects.create(test_id=3, attempt_id=3, description='try4', expected_result=20, test_result=25) - - IndexedTestModel.objects.create(test_id=4, attempt_id=0, description='try5', expected_result=5, test_result=25) - IndexedTestModel.objects.create(test_id=5, attempt_id=1, description='try6', expected_result=10, test_result=25) - IndexedTestModel.objects.create(test_id=6, attempt_id=2, description='try7', expected_result=15, test_result=25) - IndexedTestModel.objects.create(test_id=7, attempt_id=3, description='try8', expected_result=20, test_result=20) - - IndexedTestModel.objects.create(test_id=8, attempt_id=0, description='try9', expected_result=50, test_result=40) - IndexedTestModel.objects.create(test_id=9, attempt_id=1, description='try10', expected_result=60, - test_result=40) - IndexedTestModel.objects.create(test_id=10, attempt_id=2, description='try11', expected_result=70, - test_result=45) - IndexedTestModel.objects.create(test_id=11, attempt_id=3, description='try12', expected_result=75, - test_result=45) - - if CASSANDRA_VERSION >= Version('2.1'): + TestModel.objects.create( + test_id=0, + attempt_id=0, + description="try1", + expected_result=5, + test_result=30, + ) + TestModel.objects.create( + test_id=0, + attempt_id=1, + description="try2", + expected_result=10, + test_result=30, + ) + TestModel.objects.create( + test_id=0, + attempt_id=2, + description="try3", + expected_result=15, + test_result=30, + ) + TestModel.objects.create( + test_id=0, + attempt_id=3, + description="try4", + expected_result=20, + test_result=25, + ) + + TestModel.objects.create( + test_id=1, + attempt_id=0, + description="try5", + expected_result=5, + test_result=25, + ) + TestModel.objects.create( + test_id=1, + attempt_id=1, + description="try6", + expected_result=10, + test_result=25, + ) + TestModel.objects.create( + test_id=1, + attempt_id=2, + description="try7", + expected_result=15, + test_result=25, + ) + TestModel.objects.create( + test_id=1, + attempt_id=3, + description="try8", + expected_result=20, + test_result=20, + ) + + TestModel.objects.create( + test_id=2, + attempt_id=0, + description="try9", + expected_result=50, + test_result=40, + ) + TestModel.objects.create( + test_id=2, + attempt_id=1, + description="try10", + expected_result=60, + test_result=40, + ) + TestModel.objects.create( + test_id=2, + attempt_id=2, + description="try11", + expected_result=70, + test_result=45, + ) + TestModel.objects.create( + test_id=2, + attempt_id=3, + description="try12", + expected_result=75, + test_result=45, + ) + + IndexedTestModel.objects.create( + test_id=0, + attempt_id=0, + description="try1", + expected_result=5, + test_result=30, + ) + IndexedTestModel.objects.create( + test_id=1, + attempt_id=1, + description="try2", + expected_result=10, + test_result=30, + ) + IndexedTestModel.objects.create( + test_id=2, + attempt_id=2, + description="try3", + expected_result=15, + test_result=30, + ) + IndexedTestModel.objects.create( + test_id=3, + attempt_id=3, + description="try4", + expected_result=20, + test_result=25, + ) + + IndexedTestModel.objects.create( + test_id=4, + attempt_id=0, + description="try5", + expected_result=5, + test_result=25, + ) + IndexedTestModel.objects.create( + test_id=5, + attempt_id=1, + description="try6", + expected_result=10, + test_result=25, + ) + IndexedTestModel.objects.create( + test_id=6, + attempt_id=2, + description="try7", + expected_result=15, + test_result=25, + ) + IndexedTestModel.objects.create( + test_id=7, + attempt_id=3, + description="try8", + expected_result=20, + test_result=20, + ) + + IndexedTestModel.objects.create( + test_id=8, + attempt_id=0, + description="try9", + expected_result=50, + test_result=40, + ) + IndexedTestModel.objects.create( + test_id=9, + attempt_id=1, + description="try10", + expected_result=60, + test_result=40, + ) + IndexedTestModel.objects.create( + test_id=10, + attempt_id=2, + description="try11", + expected_result=70, + test_result=45, + ) + IndexedTestModel.objects.create( + test_id=11, + attempt_id=3, + description="try12", + expected_result=75, + test_result=45, + ) + + if CASSANDRA_VERSION >= Version("2.1"): drop_table(IndexedCollectionsTestModel) sync_table(IndexedCollectionsTestModel) - IndexedCollectionsTestModel.objects.create(test_id=12, attempt_id=3, description='list12', expected_result=75, - test_result=45, test_list=[1, 2, 42], test_set=set([1, 2, 3]), - test_map={'1': 1, '2': 2, '3': 3}) - IndexedCollectionsTestModel.objects.create(test_id=13, attempt_id=3, description='list13', expected_result=75, - test_result=45, test_list=[3, 4, 5], test_set=set([4, 5, 42]), - test_map={'1': 5, '2': 6, '3': 7}) - IndexedCollectionsTestModel.objects.create(test_id=14, attempt_id=3, description='list14', expected_result=75, - test_result=45, test_list=[1, 2, 3], test_set=set([1, 2, 3]), - test_map={'1': 1, '2': 2, '3': 42}) - - IndexedCollectionsTestModel.objects.create(test_id=15, attempt_id=4, description='list14', expected_result=75, - test_result=45, test_list_no_index=[1, 2, 3], test_set_no_index=set([1, 2, 3]), - test_map_no_index={'1': 1, '2': 2, '3': 42}) + IndexedCollectionsTestModel.objects.create( + test_id=12, + attempt_id=3, + description="list12", + expected_result=75, + test_result=45, + test_list=[1, 2, 42], + test_set=set([1, 2, 3]), + test_map={"1": 1, "2": 2, "3": 3}, + ) + IndexedCollectionsTestModel.objects.create( + test_id=13, + attempt_id=3, + description="list13", + expected_result=75, + test_result=45, + test_list=[3, 4, 5], + test_set=set([4, 5, 42]), + test_map={"1": 5, "2": 6, "3": 7}, + ) + IndexedCollectionsTestModel.objects.create( + test_id=14, + attempt_id=3, + description="list14", + expected_result=75, + test_result=45, + test_list=[1, 2, 3], + test_set=set([1, 2, 3]), + test_map={"1": 1, "2": 2, "3": 42}, + ) + + IndexedCollectionsTestModel.objects.create( + test_id=15, + attempt_id=4, + description="list14", + expected_result=75, + test_result=45, + test_list_no_index=[1, 2, 3], + test_set_no_index=set([1, 2, 3]), + test_map_no_index={"1": 1, "2": 2, "3": 42}, + ) @classmethod def tearDownClass(cls): @@ -387,12 +566,12 @@ def tearDownClass(cls): drop_table(CustomIndexedTestModel) drop_table(TestMultiClusteringModel) + @requires_collection_indexes class TestQuerySetCountSelectionAndIteration(BaseQuerySetUsage): - @execute_count(2) def test_count(self): - """ Tests that adding filtering statements affects the count query as expected """ + """Tests that adding filtering statements affects the count query as expected""" assert TestModel.objects.count() == 12 q = TestModel.objects(test_id=0) @@ -400,7 +579,7 @@ def test_count(self): @execute_count(2) def test_query_expression_count(self): - """ Tests that adding query statements affects the count query as expected """ + """Tests that adding query statements affects the count query as expected""" assert TestModel.objects.count() == 12 q = TestModel.objects(TestModel.test_id == 0) @@ -408,7 +587,7 @@ def test_query_expression_count(self): @execute_count(3) def test_iteration(self): - """ Tests that iterating over a query set pulls back all of the expected results """ + """Tests that iterating over a query set pulls back all of the expected results""" q = TestModel.objects(test_id=0) # tuple of expected attempt_id, expected_result values compare_set = set([(0, 5), (1, 10), (2, 15), (3, 20)]) @@ -442,9 +621,12 @@ def test_iteration(self): @execute_count(2) def test_multiple_iterations_work_properly(self): - """ Tests that iterating over a query set more than once works """ + """Tests that iterating over a query set more than once works""" # test with both the filtering method and the query method - for q in (TestModel.objects(test_id=0), TestModel.objects(TestModel.test_id == 0)): + for q in ( + TestModel.objects(test_id=0), + TestModel.objects(TestModel.test_id == 0), + ): # tuple of expected attempt_id, expected_result values compare_set = set([(0, 5), (1, 10), (2, 15), (3, 20)]) for t in q: @@ -466,8 +648,11 @@ def test_multiple_iterators_are_isolated(self): """ tests that the use of one iterator does not affect the behavior of another """ - for q in (TestModel.objects(test_id=0), TestModel.objects(TestModel.test_id == 0)): - q = q.order_by('attempt_id') + for q in ( + TestModel.objects(test_id=0), + TestModel.objects(TestModel.test_id == 0), + ): + q = q.order_by("attempt_id") expected_order = [0, 1, 2, 3] iter1 = iter(q) iter2 = iter(q) @@ -536,15 +721,16 @@ def test_get_multipleobjects_exception(self): TestModel.objects.get(test_id=1) def test_allow_filtering_flag(self): - """ - """ + """ """ + @execute_count(4) def test_non_quality_filtering(): class NonEqualityFilteringModel(Model): - example_id = columns.UUID(primary_key=True, default=uuid.uuid4) - sequence_id = columns.Integer(primary_key=True) # sequence_id is a clustering key + sequence_id = columns.Integer( + primary_key=True + ) # sequence_id is a clustering key example_type = columns.Integer(index=True) created_at = columns.DateTime() @@ -553,17 +739,25 @@ class NonEqualityFilteringModel(Model): # setup table, etc. - NonEqualityFilteringModel.create(sequence_id=1, example_type=0, created_at=datetime.now()) - NonEqualityFilteringModel.create(sequence_id=3, example_type=0, created_at=datetime.now()) - NonEqualityFilteringModel.create(sequence_id=5, example_type=1, created_at=datetime.now()) - - qa = NonEqualityFilteringModel.objects(NonEqualityFilteringModel.sequence_id > 3).allow_filtering() + NonEqualityFilteringModel.create( + sequence_id=1, example_type=0, created_at=datetime.now() + ) + NonEqualityFilteringModel.create( + sequence_id=3, example_type=0, created_at=datetime.now() + ) + NonEqualityFilteringModel.create( + sequence_id=5, example_type=1, created_at=datetime.now() + ) + + qa = NonEqualityFilteringModel.objects( + NonEqualityFilteringModel.sequence_id > 3 + ).allow_filtering() num = qa.count() assert num == 1, num + @requires_collection_indexes class TestQuerySetDistinct(BaseQuerySetUsage): - @execute_count(1) def test_distinct_without_parameter(self): q = TestModel.objects.distinct() @@ -571,32 +765,32 @@ def test_distinct_without_parameter(self): @execute_count(1) def test_distinct_with_parameter(self): - q = TestModel.objects.distinct(['test_id']) + q = TestModel.objects.distinct(["test_id"]) assert len(q) == 3 @execute_count(1) def test_distinct_with_filter(self): - q = TestModel.objects.distinct(['test_id']).filter(test_id__in=[1, 2]) + q = TestModel.objects.distinct(["test_id"]).filter(test_id__in=[1, 2]) assert len(q) == 2 @execute_count(1) def test_distinct_with_non_partition(self): with pytest.raises(InvalidRequest): - q = TestModel.objects.distinct(['description']).filter(test_id__in=[1, 2]) + q = TestModel.objects.distinct(["description"]).filter(test_id__in=[1, 2]) len(q) @execute_count(1) def test_zero_result(self): - q = TestModel.objects.distinct(['test_id']).filter(test_id__in=[52]) + q = TestModel.objects.distinct(["test_id"]).filter(test_id__in=[52]) assert len(q) == 0 @greaterthancass21 @execute_count(2) def test_distinct_with_explicit_count(self): - q = TestModel.objects.distinct(['test_id']) + q = TestModel.objects.distinct(["test_id"]) assert q.count() == 3 - q = TestModel.objects.distinct(['test_id']).filter(test_id__in=[1, 2]) + q = TestModel.objects.distinct(["test_id"]).filter(test_id__in=[1, 2]) assert q.count() == 2 @@ -604,12 +798,12 @@ def test_distinct_with_explicit_count(self): class TestQuerySetOrdering(BaseQuerySetUsage): @execute_count(2) def test_order_by_success_case(self): - q = TestModel.objects(test_id=0).order_by('attempt_id') + q = TestModel.objects(test_id=0).order_by("attempt_id") expected_order = [0, 1, 2, 3] for model, expect in zip(q, expected_order): assert model.attempt_id == expect - q = q.order_by().order_by('-attempt_id') + q = q.order_by().order_by("-attempt_id") expected_order.reverse() for model, expect in zip(q, expected_order): assert model.attempt_id == expect @@ -617,19 +811,19 @@ def test_order_by_success_case(self): def test_ordering_by_non_second_primary_keys_fail(self): # kwarg filtering with pytest.raises(query.QueryException): - TestModel.objects(test_id=0).order_by('test_id') + TestModel.objects(test_id=0).order_by("test_id") # kwarg filtering with pytest.raises(query.QueryException): - TestModel.objects(TestModel.test_id == 0).order_by('test_id') + TestModel.objects(TestModel.test_id == 0).order_by("test_id") def test_ordering_by_non_primary_keys_fails(self): with pytest.raises(query.QueryException): - TestModel.objects(test_id=0).order_by('description') + TestModel.objects(test_id=0).order_by("description") def test_ordering_on_indexed_columns_fails(self): with pytest.raises(query.QueryException): - IndexedTestModel.objects(test_id=0).order_by('attempt_id') + IndexedTestModel.objects(test_id=0).order_by("attempt_id") @execute_count(8) def test_ordering_on_multiple_clustering_columns(self): @@ -639,42 +833,49 @@ def test_ordering_on_multiple_clustering_columns(self): TestMultiClusteringModel.create(one=1, two=1, three=1) TestMultiClusteringModel.create(one=1, two=1, three=3) - results = TestMultiClusteringModel.objects.filter(one=1, two=1).order_by('-two', '-three') + results = TestMultiClusteringModel.objects.filter(one=1, two=1).order_by( + "-two", "-three" + ) assert [r.three for r in results] == [5, 4, 3, 2, 1] - results = TestMultiClusteringModel.objects.filter(one=1, two=1).order_by('two', 'three') + results = TestMultiClusteringModel.objects.filter(one=1, two=1).order_by( + "two", "three" + ) assert [r.three for r in results] == [1, 2, 3, 4, 5] - results = TestMultiClusteringModel.objects.filter(one=1, two=1).order_by('two').order_by('three') + results = ( + TestMultiClusteringModel.objects.filter(one=1, two=1) + .order_by("two") + .order_by("three") + ) assert [r.three for r in results] == [1, 2, 3, 4, 5] @requires_collection_indexes class TestQuerySetSlicing(BaseQuerySetUsage): - @execute_count(1) def test_out_of_range_index_raises_error(self): - q = TestModel.objects(test_id=0).order_by('attempt_id') + q = TestModel.objects(test_id=0).order_by("attempt_id") with pytest.raises(IndexError): q[10] @execute_count(1) def test_array_indexing_works_properly(self): - q = TestModel.objects(test_id=0).order_by('attempt_id') + q = TestModel.objects(test_id=0).order_by("attempt_id") expected_order = [0, 1, 2, 3] for i in range(len(q)): assert q[i].attempt_id == expected_order[i] @execute_count(1) def test_negative_indexing_works_properly(self): - q = TestModel.objects(test_id=0).order_by('attempt_id') + q = TestModel.objects(test_id=0).order_by("attempt_id") expected_order = [0, 1, 2, 3] assert q[-1].attempt_id == expected_order[-1] assert q[-2].attempt_id == expected_order[-2] @execute_count(1) def test_slicing_works_properly(self): - q = TestModel.objects(test_id=0).order_by('attempt_id') + q = TestModel.objects(test_id=0).order_by("attempt_id") expected_order = [0, 1, 2, 3] for model, expect in zip(q[1:3], expected_order[1:3]): @@ -685,7 +886,7 @@ def test_slicing_works_properly(self): @execute_count(1) def test_negative_slicing(self): - q = TestModel.objects(test_id=0).order_by('attempt_id') + q = TestModel.objects(test_id=0).order_by("attempt_id") expected_order = [0, 1, 2, 3] for model, expect in zip(q[-3:], expected_order[-3:]): @@ -706,7 +907,6 @@ def test_negative_slicing(self): @requires_collection_indexes class TestQuerySetValidation(BaseQuerySetUsage): - def test_primary_key_or_index_must_be_specified(self): """ Tests that queries that don't have an equals relation to a primary key or indexed field fail @@ -756,45 +956,73 @@ def test_custom_indexed_field_can_be_queried(self): """ with pytest.raises(query.QueryException): - list(CustomIndexedTestModel.objects.filter(data='test')) # not custom indexed + list( + CustomIndexedTestModel.objects.filter(data="test") + ) # not custom indexed # It should return InvalidRequest if target an indexed columns with pytest.raises(InvalidRequest): - list(CustomIndexedTestModel.objects.filter(indexed='test', data='test')) + list(CustomIndexedTestModel.objects.filter(indexed="test", data="test")) # It should return InvalidRequest if target an indexed columns with pytest.raises(InvalidRequest): - list(CustomIndexedTestModel.objects.filter(description='test', data='test')) + list(CustomIndexedTestModel.objects.filter(description="test", data="test")) # equals operator, server error since there is no real index, but it passes with pytest.raises(InvalidRequest): - list(CustomIndexedTestModel.objects.filter(description='test')) + list(CustomIndexedTestModel.objects.filter(description="test")) with pytest.raises(InvalidRequest): - list(CustomIndexedTestModel.objects.filter(test_id=1, description='test')) + list(CustomIndexedTestModel.objects.filter(test_id=1, description="test")) # gte operator, server error since there is no real index, but it passes # this can't work with a secondary index with pytest.raises(InvalidRequest): - list(CustomIndexedTestModel.objects.filter(description__gte='test')) + list(CustomIndexedTestModel.objects.filter(description__gte="test")) with TestCluster().connect() as session: - session.execute("CREATE INDEX custom_index_cqlengine ON {}.{} (description)". - format(DEFAULT_KEYSPACE, CustomIndexedTestModel._table_name)) + session.execute( + "CREATE INDEX custom_index_cqlengine ON {}.{} (description)".format( + DEFAULT_KEYSPACE, CustomIndexedTestModel._table_name + ) + ) - list(CustomIndexedTestModel.objects.filter(description='test')) - list(CustomIndexedTestModel.objects.filter(test_id=1, description='test')) + list(CustomIndexedTestModel.objects.filter(description="test")) + list(CustomIndexedTestModel.objects.filter(test_id=1, description="test")) @requires_collection_indexes class TestQuerySetDelete(BaseQuerySetUsage): - @execute_count(9) def test_delete(self): - TestModel.objects.create(test_id=3, attempt_id=0, description='try9', expected_result=50, test_result=40) - TestModel.objects.create(test_id=3, attempt_id=1, description='try10', expected_result=60, test_result=40) - TestModel.objects.create(test_id=3, attempt_id=2, description='try11', expected_result=70, test_result=45) - TestModel.objects.create(test_id=3, attempt_id=3, description='try12', expected_result=75, test_result=45) + TestModel.objects.create( + test_id=3, + attempt_id=0, + description="try9", + expected_result=50, + test_result=40, + ) + TestModel.objects.create( + test_id=3, + attempt_id=1, + description="try10", + expected_result=60, + test_result=40, + ) + TestModel.objects.create( + test_id=3, + attempt_id=2, + description="try11", + expected_result=70, + test_result=45, + ) + TestModel.objects.create( + test_id=3, + attempt_id=3, + description="try12", + expected_result=75, + test_result=45, + ) assert TestModel.objects.count() == 16 assert TestModel.objects(test_id=3).count() == 4 @@ -805,12 +1033,12 @@ def test_delete(self): assert TestModel.objects(test_id=3).count() == 0 def test_delete_without_partition_key(self): - """ Tests that attempting to delete a model without defining a partition key fails """ + """Tests that attempting to delete a model without defining a partition key fails""" with pytest.raises(query.QueryException): TestModel.objects(attempt_id=0).delete() def test_delete_without_any_where_args(self): - """ Tests that attempting to delete a whole table without any arguments will fail """ + """Tests that attempting to delete a whole table without any arguments will fail""" with pytest.raises(query.QueryException): TestModel.objects(attempt_id=0).delete() @@ -838,7 +1066,6 @@ def test_range_deletion(self): class TimeUUIDQueryModel(Model): - partition = columns.UUID(primary_key=True) time = columns.TimeUUID(primary_key=True) data = columns.Text(required=False) @@ -871,91 +1098,128 @@ def test_tzaware_datetime_support(self): TimeUUIDQueryModel.create( partition=pk, time=uuid_from_time(midpoint_utc - timedelta(minutes=1)), - data='1') + data="1", + ) TimeUUIDQueryModel.create( - partition=pk, - time=uuid_from_time(midpoint_utc), - data='2') + partition=pk, time=uuid_from_time(midpoint_utc), data="2" + ) TimeUUIDQueryModel.create( partition=pk, time=uuid_from_time(midpoint_utc + timedelta(minutes=1)), - data='3') - - assert ['1', '2'] == [o.data for o in TimeUUIDQueryModel.filter( - TimeUUIDQueryModel.partition == pk, - TimeUUIDQueryModel.time <= functions.MaxTimeUUID(midpoint_utc))] - - assert ['1', '2'] == [o.data for o in TimeUUIDQueryModel.filter( - TimeUUIDQueryModel.partition == pk, - TimeUUIDQueryModel.time <= functions.MaxTimeUUID(midpoint_helsinki))] + data="3", + ) - assert ['2', '3'] == [o.data for o in TimeUUIDQueryModel.filter( - TimeUUIDQueryModel.partition == pk, - TimeUUIDQueryModel.time >= functions.MinTimeUUID(midpoint_utc))] - - assert ['2', '3'] == [o.data for o in TimeUUIDQueryModel.filter( - TimeUUIDQueryModel.partition == pk, - TimeUUIDQueryModel.time >= functions.MinTimeUUID(midpoint_helsinki))] + assert ["1", "2"] == [ + o.data + for o in TimeUUIDQueryModel.filter( + TimeUUIDQueryModel.partition == pk, + TimeUUIDQueryModel.time <= functions.MaxTimeUUID(midpoint_utc), + ) + ] + + assert ["1", "2"] == [ + o.data + for o in TimeUUIDQueryModel.filter( + TimeUUIDQueryModel.partition == pk, + TimeUUIDQueryModel.time <= functions.MaxTimeUUID(midpoint_helsinki), + ) + ] + + assert ["2", "3"] == [ + o.data + for o in TimeUUIDQueryModel.filter( + TimeUUIDQueryModel.partition == pk, + TimeUUIDQueryModel.time >= functions.MinTimeUUID(midpoint_utc), + ) + ] + + assert ["2", "3"] == [ + o.data + for o in TimeUUIDQueryModel.filter( + TimeUUIDQueryModel.partition == pk, + TimeUUIDQueryModel.time >= functions.MinTimeUUID(midpoint_helsinki), + ) + ] @execute_count(8) def test_success_case(self): - """ Test that the min and max time uuid functions work as expected """ + """Test that the min and max time uuid functions work as expected""" pk = uuid4() startpoint = datetime.utcnow() - TimeUUIDQueryModel.create(partition=pk, time=uuid_from_time(startpoint + timedelta(seconds=1)), data='1') - TimeUUIDQueryModel.create(partition=pk, time=uuid_from_time(startpoint + timedelta(seconds=2)), data='2') + TimeUUIDQueryModel.create( + partition=pk, + time=uuid_from_time(startpoint + timedelta(seconds=1)), + data="1", + ) + TimeUUIDQueryModel.create( + partition=pk, + time=uuid_from_time(startpoint + timedelta(seconds=2)), + data="2", + ) midpoint = startpoint + timedelta(seconds=3) - TimeUUIDQueryModel.create(partition=pk, time=uuid_from_time(startpoint + timedelta(seconds=4)), data='3') - TimeUUIDQueryModel.create(partition=pk, time=uuid_from_time(startpoint + timedelta(seconds=5)), data='4') + TimeUUIDQueryModel.create( + partition=pk, + time=uuid_from_time(startpoint + timedelta(seconds=4)), + data="3", + ) + TimeUUIDQueryModel.create( + partition=pk, + time=uuid_from_time(startpoint + timedelta(seconds=5)), + data="4", + ) # test kwarg filtering - q = TimeUUIDQueryModel.filter(partition=pk, time__lte=functions.MaxTimeUUID(midpoint)) + q = TimeUUIDQueryModel.filter( + partition=pk, time__lte=functions.MaxTimeUUID(midpoint) + ) q = [d for d in q] assert len(q) == 2, "Got: %s" % q datas = [d.data for d in q] - assert '1' in datas - assert '2' in datas + assert "1" in datas + assert "2" in datas - q = TimeUUIDQueryModel.filter(partition=pk, time__gte=functions.MinTimeUUID(midpoint)) + q = TimeUUIDQueryModel.filter( + partition=pk, time__gte=functions.MinTimeUUID(midpoint) + ) assert len(q) == 2 datas = [d.data for d in q] - assert '3' in datas - assert '4' in datas + assert "3" in datas + assert "4" in datas # test query expression filtering q = TimeUUIDQueryModel.filter( TimeUUIDQueryModel.partition == pk, - TimeUUIDQueryModel.time <= functions.MaxTimeUUID(midpoint) + TimeUUIDQueryModel.time <= functions.MaxTimeUUID(midpoint), ) q = [d for d in q] assert len(q) == 2 datas = [d.data for d in q] - assert '1' in datas - assert '2' in datas + assert "1" in datas + assert "2" in datas q = TimeUUIDQueryModel.filter( TimeUUIDQueryModel.partition == pk, - TimeUUIDQueryModel.time >= functions.MinTimeUUID(midpoint) + TimeUUIDQueryModel.time >= functions.MinTimeUUID(midpoint), ) assert len(q) == 2 datas = [d.data for d in q] - assert '3' in datas - assert '4' in datas + assert "3" in datas + assert "4" in datas @requires_collection_indexes class TestInOperator(BaseQuerySetUsage): @execute_count(1) def test_kwarg_success_case(self): - """ Tests the in operator works with the kwarg query method """ + """Tests the in operator works with the kwarg query method""" q = TestModel.filter(test_id__in=[0, 1]) assert q.count() == 8 @execute_count(1) def test_query_expression_success_case(self): - """ Tests the in operator works with the query expression query method """ + """Tests the in operator works with the query expression query method""" q = TestModel.filter(TestModel.test_id.in_([0, 1])) assert q.count() == 8 @@ -970,10 +1234,12 @@ def test_bool(self): @test_category object_mapper """ + class bool_model(Model): k = columns.Integer(primary_key=True) b = columns.Boolean(primary_key=True) v = columns.Integer(default=3) + sync_table(bool_model) bool_model.create(k=0, b=True) @@ -993,25 +1259,26 @@ def test_bool_filter(self): @test_category object_mapper """ + class bool_model2(Model): k = columns.Boolean(primary_key=True) b = columns.Integer(primary_key=True) v = columns.Text() + drop_table(bool_model2) sync_table(bool_model2) - bool_model2.create(k=True, b=1, v='a') - bool_model2.create(k=False, b=1, v='b') + bool_model2.create(k=True, b=1, v="a") + bool_model2.create(k=False, b=1, v="b") assert len(list(bool_model2.objects(k__in=(True, False)))) == 2 @greaterthancass20 @requires_collection_indexes class TestContainsOperator(BaseQuerySetUsage): - @execute_count(6) def test_kwarg_success_case(self): - """ Tests the CONTAINS operator works with the kwarg query method """ + """Tests the CONTAINS operator works with the kwarg query method""" q = IndexedCollectionsTestModel.filter(test_list__contains=1) assert q.count() == 2 @@ -1042,46 +1309,65 @@ def test_kwarg_success_case(self): @execute_count(6) def test_query_expression_success_case(self): - """ Tests the CONTAINS operator works with the query expression query method """ - q = IndexedCollectionsTestModel.filter(IndexedCollectionsTestModel.test_list.contains_(1)) + """Tests the CONTAINS operator works with the query expression query method""" + q = IndexedCollectionsTestModel.filter( + IndexedCollectionsTestModel.test_list.contains_(1) + ) assert q.count() == 2 - q = IndexedCollectionsTestModel.filter(IndexedCollectionsTestModel.test_list.contains_(13)) + q = IndexedCollectionsTestModel.filter( + IndexedCollectionsTestModel.test_list.contains_(13) + ) assert q.count() == 0 - q = IndexedCollectionsTestModel.filter(IndexedCollectionsTestModel.test_set.contains_(3)) + q = IndexedCollectionsTestModel.filter( + IndexedCollectionsTestModel.test_set.contains_(3) + ) assert q.count() == 2 - q = IndexedCollectionsTestModel.filter(IndexedCollectionsTestModel.test_set.contains_(13)) + q = IndexedCollectionsTestModel.filter( + IndexedCollectionsTestModel.test_set.contains_(13) + ) assert q.count() == 0 - q = IndexedCollectionsTestModel.filter(IndexedCollectionsTestModel.test_map.contains_(42)) + q = IndexedCollectionsTestModel.filter( + IndexedCollectionsTestModel.test_map.contains_(42) + ) assert q.count() == 1 - q = IndexedCollectionsTestModel.filter(IndexedCollectionsTestModel.test_map.contains_(13)) + q = IndexedCollectionsTestModel.filter( + IndexedCollectionsTestModel.test_map.contains_(13) + ) assert q.count() == 0 with pytest.raises(QueryException): - q = IndexedCollectionsTestModel.filter(IndexedCollectionsTestModel.test_map_no_index.contains_(1)) + q = IndexedCollectionsTestModel.filter( + IndexedCollectionsTestModel.test_map_no_index.contains_(1) + ) assert q.count() == 0 with pytest.raises(QueryException): - q = IndexedCollectionsTestModel.filter(IndexedCollectionsTestModel.test_map_no_index.contains_(1)) + q = IndexedCollectionsTestModel.filter( + IndexedCollectionsTestModel.test_map_no_index.contains_(1) + ) assert q.count() == 0 with pytest.raises(QueryException): - q = IndexedCollectionsTestModel.filter(IndexedCollectionsTestModel.test_map_no_index.contains_(1)) + q = IndexedCollectionsTestModel.filter( + IndexedCollectionsTestModel.test_map_no_index.contains_(1) + ) assert q.count() == 0 @requires_collection_indexes class TestValuesList(BaseQuerySetUsage): - @execute_count(2) def test_values_list(self): q = TestModel.objects.filter(test_id=0, attempt_id=1) - item = q.values_list('test_id', 'attempt_id', 'description', 'expected_result', 'test_result').first() - assert item == [0, 1, 'try2', 10, 30] + item = q.values_list( + "test_id", "attempt_id", "description", "expected_result", "test_result" + ).first() + assert item == [0, 1, "try2", 10, 30] - item = q.values_list('expected_result', flat=True).first() + item = q.values_list("expected_result", flat=True).first() assert item == 10 @@ -1098,19 +1384,24 @@ class PageQueryTests(BaseCassEngTestCase): @execute_count(3) def test_paged_result_handling(self): if PROTOCOL_VERSION < 2: - raise unittest.SkipTest("Paging requires native protocol 2+, currently using: {0}".format(PROTOCOL_VERSION)) + raise unittest.SkipTest( + "Paging requires native protocol 2+, currently using: {0}".format( + PROTOCOL_VERSION + ) + ) # addresses #225 class PagingTest(Model): id = columns.Integer(primary_key=True) val = columns.Integer() + sync_table(PagingTest) PagingTest.create(id=1, val=1) PagingTest.create(id=2, val=2) session = get_session() - with mock.patch.object(session, 'default_fetch_size', 1): + with mock.patch.object(session, "default_fetch_size", 1): results = PagingTest.objects()[:] assert len(results) == 2 @@ -1119,41 +1410,41 @@ class PagingTest(Model): @requires_collection_indexes class ModelQuerySetTimeoutTestCase(BaseQuerySetUsage): def test_default_timeout(self): - with mock.patch.object(Session, 'execute') as mock_execute: + with mock.patch.object(Session, "execute") as mock_execute: list(TestModel.objects()) - assert mock_execute.call_args[-1]['timeout'] == NOT_SET + assert mock_execute.call_args[-1]["timeout"] == NOT_SET def test_float_timeout(self): - with mock.patch.object(Session, 'execute') as mock_execute: + with mock.patch.object(Session, "execute") as mock_execute: list(TestModel.objects().timeout(0.5)) - assert mock_execute.call_args[-1]['timeout'] == 0.5 + assert mock_execute.call_args[-1]["timeout"] == 0.5 def test_none_timeout(self): - with mock.patch.object(Session, 'execute') as mock_execute: + with mock.patch.object(Session, "execute") as mock_execute: list(TestModel.objects().timeout(None)) - assert mock_execute.call_args[-1]['timeout'] == None + assert mock_execute.call_args[-1]["timeout"] == None @requires_collection_indexes class DMLQueryTimeoutTestCase(BaseQuerySetUsage): def setUp(self): - self.model = TestModel(test_id=1, attempt_id=1, description='timeout test') + self.model = TestModel(test_id=1, attempt_id=1, description="timeout test") super(DMLQueryTimeoutTestCase, self).setUp() def test_default_timeout(self): - with mock.patch.object(Session, 'execute') as mock_execute: + with mock.patch.object(Session, "execute") as mock_execute: self.model.save() - assert mock_execute.call_args[-1]['timeout'] == NOT_SET + assert mock_execute.call_args[-1]["timeout"] == NOT_SET def test_float_timeout(self): - with mock.patch.object(Session, 'execute') as mock_execute: + with mock.patch.object(Session, "execute") as mock_execute: self.model.timeout(0.5).save() - assert mock_execute.call_args[-1]['timeout'] == 0.5 + assert mock_execute.call_args[-1]["timeout"] == 0.5 def test_none_timeout(self): - with mock.patch.object(Session, 'execute') as mock_execute: + with mock.patch.object(Session, "execute") as mock_execute: self.model.timeout(None).save() - assert mock_execute.call_args[-1]['timeout'] == None + assert mock_execute.call_args[-1]["timeout"] == None def test_timeout_then_batch(self): b = query.BatchQuery() @@ -1169,31 +1460,30 @@ def test_batch_then_timeout(self): class DBFieldModel(Model): - k0 = columns.Integer(partition_key=True, db_field='a') - k1 = columns.Integer(partition_key=True, db_field='b') - c0 = columns.Integer(primary_key=True, db_field='c') - v0 = columns.Integer(db_field='d') - v1 = columns.Integer(db_field='e', index=True) + k0 = columns.Integer(partition_key=True, db_field="a") + k1 = columns.Integer(partition_key=True, db_field="b") + c0 = columns.Integer(primary_key=True, db_field="c") + v0 = columns.Integer(db_field="d") + v1 = columns.Integer(db_field="e", index=True) class DBFieldModelMixed1(Model): - k0 = columns.Integer(partition_key=True, db_field='a') + k0 = columns.Integer(partition_key=True, db_field="a") k1 = columns.Integer(partition_key=True) - c0 = columns.Integer(primary_key=True, db_field='c') - v0 = columns.Integer(db_field='d') + c0 = columns.Integer(primary_key=True, db_field="c") + v0 = columns.Integer(db_field="d") v1 = columns.Integer(index=True) class DBFieldModelMixed2(Model): k0 = columns.Integer(partition_key=True) - k1 = columns.Integer(partition_key=True, db_field='b') + k1 = columns.Integer(partition_key=True, db_field="b") c0 = columns.Integer(primary_key=True) - v0 = columns.Integer(db_field='d') - v1 = columns.Integer(index=True, db_field='e') + v0 = columns.Integer(db_field="d") + v1 = columns.Integer(index=True, db_field="e") class TestModelQueryWithDBField(BaseCassEngTestCase): - def setUp(cls): super(TestModelQueryWithDBField, cls).setUpClass() cls.model_list = [DBFieldModel, DBFieldModelMixed1, DBFieldModelMixed2] @@ -1217,15 +1507,15 @@ def test_basic_crud(self): @test_category object_mapper """ for model in self.model_list: - values = {'k0': 1, 'k1': 2, 'c0': 3, 'v0': 4, 'v1': 5} + values = {"k0": 1, "k1": 2, "c0": 3, "v0": 4, "v1": 5} # create i = model.create(**values) i = model.objects(k0=i.k0, k1=i.k1).first() assert i == model(**values) # create - values['v0'] = 101 - i.update(v0=values['v0']) + values["v0"] = 101 + i.update(v0=values["v0"]) i = model.objects(k0=i.k0, k1=i.k1).first() assert i == model(**values) @@ -1254,16 +1544,20 @@ def test_slice(self): @test_category object_mapper """ for model in self.model_list: - values = {'k0': 1, 'k1': 3, 'c0': 3, 'v0': 4, 'v1': 5} + values = {"k0": 1, "k1": 3, "c0": 3, "v0": 4, "v1": 5} clustering_values = range(3) for c in clustering_values: - values['c0'] = c + values["c0"] = c i = model.create(**values) assert model.objects(k0=i.k0, k1=i.k1).count() == len(clustering_values) assert model.objects(k0=i.k0, k1=i.k1, c0=i.c0).count() == 1 - assert model.objects(k0=i.k0, k1=i.k1, c0__lt=i.c0).count() == len(clustering_values[:-1]) - assert model.objects(k0=i.k0, k1=i.k1, c0__gt=0).count() == len(clustering_values[1:]) + assert model.objects(k0=i.k0, k1=i.k1, c0__lt=i.c0).count() == len( + clustering_values[:-1] + ) + assert model.objects(k0=i.k0, k1=i.k1, c0__gt=0).count() == len( + clustering_values[1:] + ) @execute_count(15) def test_order(self): @@ -1277,13 +1571,19 @@ def test_order(self): @test_category object_mapper """ for model in self.model_list: - values = {'k0': 1, 'k1': 4, 'c0': 3, 'v0': 4, 'v1': 5} + values = {"k0": 1, "k1": 4, "c0": 3, "v0": 4, "v1": 5} clustering_values = range(3) for c in clustering_values: - values['c0'] = c + values["c0"] = c i = model.create(**values) - assert model.objects(k0=i.k0, k1=i.k1).order_by('c0').first().c0 == clustering_values[0] - assert model.objects(k0=i.k0, k1=i.k1).order_by('-c0').first().c0 == clustering_values[-1] + assert ( + model.objects(k0=i.k0, k1=i.k1).order_by("c0").first().c0 + == clustering_values[0] + ) + assert ( + model.objects(k0=i.k0, k1=i.k1).order_by("-c0").first().c0 + == clustering_values[-1] + ) @execute_count(15) def test_index(self): @@ -1297,11 +1597,11 @@ def test_index(self): @test_category object_mapper """ for model in self.model_list: - values = {'k0': 1, 'k1': 5, 'c0': 3, 'v0': 4, 'v1': 5} + values = {"k0": 1, "k1": 5, "c0": 3, "v0": 4, "v1": 5} clustering_values = range(3) for c in clustering_values: - values['c0'] = c - values['v1'] = c + values["c0"] = c + values["v1"] = c i = model.create(**values) assert model.objects(k0=i.k0, k1=i.k1).count() == len(clustering_values) assert model.objects(k0=i.k0, k1=i.k1, v1=0).count() == 1 @@ -1318,7 +1618,7 @@ def test_db_field_names_used(self): @test_category object_mapper """ - values = ('k0', 'k1', 'c0', 'v0', 'v1') + values = ("k0", "k1", "c0", "v0", "v1") # Test QuerySet Path b = BatchQuery() DBFieldModel.objects(k0=1).batch(b).update( @@ -1341,10 +1641,15 @@ def test_db_field_names_used(self): def test_db_field_value_list(self): DBFieldModel.create(k0=0, k1=0, c0=0, v0=4, v1=5) - assert DBFieldModel.objects.filter(c0=0, k0=0, k1=0).values_list('c0', 'v0')._defer_fields == {'a', 'c', 'b'} - assert DBFieldModel.objects.filter(c0=0, k0=0, k1=0).values_list('c0', 'v0')._only_fields == ['c', 'd'] + assert DBFieldModel.objects.filter(c0=0, k0=0, k1=0).values_list( + "c0", "v0" + )._defer_fields == {"a", "c", "b"} + assert DBFieldModel.objects.filter(c0=0, k0=0, k1=0).values_list( + "c0", "v0" + )._only_fields == ["c", "d"] + + list(DBFieldModel.objects.filter(c0=0, k0=0, k1=0).values_list("c0", "v0")) - list(DBFieldModel.objects.filter(c0=0, k0=0, k1=0).values_list('c0', 'v0')) class TestModelSmall(Model): __test__ = False @@ -1432,6 +1737,7 @@ class TestModelQueryWithDifferedFeld(BaseCassEngTestCase): @test_category object_mapper """ + @classmethod def setUpClass(cls): super(TestModelQueryWithDifferedFeld, cls).setUpClass() @@ -1445,13 +1751,21 @@ def tearDownClass(cls): @execute_count(8) def test_defaultFetchSize(self): # Populate Table - People.objects.create(last_name="Smith", first_name="John", birthday=datetime.now()) - People.objects.create(last_name="Bestwater", first_name="Alan", birthday=datetime.now()) - People.objects.create(last_name="Smith", first_name="Greg", birthday=datetime.now()) - People.objects.create(last_name="Smith", first_name="Adam", birthday=datetime.now()) + People.objects.create( + last_name="Smith", first_name="John", birthday=datetime.now() + ) + People.objects.create( + last_name="Bestwater", first_name="Alan", birthday=datetime.now() + ) + People.objects.create( + last_name="Smith", first_name="Greg", birthday=datetime.now() + ) + People.objects.create( + last_name="Smith", first_name="Adam", birthday=datetime.now() + ) # Check query constructions - expected_fields = ['first_name', 'birthday'] + expected_fields = ["first_name", "birthday"] assert People.filter(last_name="Smith")._select_fields() == expected_fields # Validate correct fields are fetched smiths = list(People.filter(last_name="Smith")) @@ -1462,11 +1776,21 @@ def test_defaultFetchSize(self): sync_table(People2) # populate new format - People2.objects.create(last_name="Smith", first_name="Chris", middle_name="Raymond", birthday=datetime.now()) - People2.objects.create(last_name="Smith", first_name="Andrew", middle_name="Micheal", birthday=datetime.now()) + People2.objects.create( + last_name="Smith", + first_name="Chris", + middle_name="Raymond", + birthday=datetime.now(), + ) + People2.objects.create( + last_name="Smith", + first_name="Andrew", + middle_name="Micheal", + birthday=datetime.now(), + ) # validate query construction - expected_fields = ['first_name', 'middle_name', 'birthday'] + expected_fields = ["first_name", "middle_name", "birthday"] assert People2.filter(last_name="Smith")._select_fields() == expected_fields # validate correct items are returneds From 2c79bf80d6dabe2be06852c390aab688d737e1fa Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Wed, 4 Mar 2026 19:09:56 +0200 Subject: [PATCH 10/18] remove: dead WeakSet fallback and custom implementation WeakSet has been available in the weakref module since Python 2.7+ and all Python 3 versions. The try/except ImportError fallback to cassandra.util.WeakSet was unreachable dead code on Python 3. - Replace try/except with direct 'from weakref import WeakSet' in cluster.py, pool.py, and io/asyncorereactor.py - Delete the ~210-line custom WeakSet class and its _IterationGuard helper from cassandra/util.py - Remove the now-unused 'from _weakref import ref' import --- cassandra/cluster.py | 5 +- cassandra/io/asyncorereactor.py | 75 ++-- cassandra/pool.py | 263 ++++++++--- cassandra/util.py | 755 +++++++++++++++----------------- 4 files changed, 589 insertions(+), 509 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 547290ff0f..a7f2b98a10 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -171,10 +171,7 @@ # TODO: remove it when eventlet issue would be fixed EventletConnection = None -try: - from weakref import WeakSet -except ImportError: - from cassandra.util import WeakSet # NOQA +from weakref import WeakSet def _is_gevent_monkey_patched(): diff --git a/cassandra/io/asyncorereactor.py b/cassandra/io/asyncorereactor.py index 02466ad0d2..b43633d352 100644 --- a/cassandra/io/asyncorereactor.py +++ b/cassandra/io/asyncorereactor.py @@ -24,12 +24,10 @@ import sys import ssl -try: - from weakref import WeakSet -except ImportError: - from cassandra.util import WeakSet # noqa +from weakref import WeakSet from cassandra import DependencyException + try: import asyncore except ModuleNotFoundError: @@ -39,14 +37,20 @@ "other event loop implementations." ) -from cassandra.connection import Connection, ConnectionShutdown, NONBLOCKING, Timer, TimerManager - +from cassandra.connection import ( + Connection, + ConnectionShutdown, + NONBLOCKING, + Timer, + TimerManager, +) log = logging.getLogger(__name__) _dispatcher_map = {} + def _cleanup(loop): if loop: loop._cleanup() @@ -80,7 +84,6 @@ def wait(self, timeout=None): class _PipeWrapper(object): - def __init__(self, fd): self.fd = fd @@ -98,7 +101,6 @@ def getsockopt(self, level, optname, buflen=None): class _AsyncoreDispatcher(asyncore.dispatcher): - def __init__(self, socket): asyncore.dispatcher.__init__(self, map=_dispatcher_map) # inject after to avoid base class validation @@ -120,7 +122,6 @@ def loop(self, timeout): class _AsyncorePipeDispatcher(_AsyncoreDispatcher): - def __init__(self): self.read_fd, self.write_fd = os.pipe() _AsyncoreDispatcher.__init__(self, _PipeWrapper(self.read_fd)) @@ -136,7 +137,7 @@ def handle_read(self): def notify_loop(self): if not self._notified: self._notified = True - os.write(self.write_fd, b'x') + os.write(self.write_fd, b"x") class _AsyncoreUDPDispatcher(_AsyncoreDispatcher): @@ -152,7 +153,8 @@ class _AsyncoreUDPDispatcher(_AsyncoreDispatcher): AsyncoreLoop._loop_dispatch_class = _AsyncoreUDPDispatcher """ - bind_address = ('localhost', 10000) + + bind_address = ("localhost", 10000) def __init__(self): self._socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) @@ -172,14 +174,13 @@ def handle_read(self): def notify_loop(self): if not self._notified: self._notified = True - self._socket.sendto(b'', self.bind_address) + self._socket.sendto(b"", self.bind_address) def loop(self, timeout): asyncore.loop(timeout=timeout, use_poll=False, map=_dispatcher_map, count=1) class _BusyWaitDispatcher(object): - max_write_latency = 0.001 """ Timeout pushed down to asyncore select/poll. Dictates the amount of time it will sleep before coming back to check @@ -193,7 +194,12 @@ def loop(self, timeout): if not _dispatcher_map: time.sleep(0.005) count = timeout // self.max_write_latency - asyncore.loop(timeout=self.max_write_latency, use_poll=True, map=_dispatcher_map, count=count) + asyncore.loop( + timeout=self.max_write_latency, + use_poll=True, + map=_dispatcher_map, + count=count, + ) def validate(self): pass @@ -203,10 +209,11 @@ def close(self): class AsyncoreLoop(object): - timer_resolution = 0.1 # used as the max interval to be in the io loop before returning to service timeouts - _loop_dispatch_class = _AsyncorePipeDispatcher if os.name != 'nt' else _BusyWaitDispatcher + _loop_dispatch_class = ( + _AsyncorePipeDispatcher if os.name != "nt" else _BusyWaitDispatcher + ) def __init__(self): self._pid = os.getpid() @@ -223,7 +230,10 @@ def __init__(self): dispatcher.validate() log.debug("Validated loop dispatch with %s", self._loop_dispatch_class) except Exception: - log.exception("Failed validating loop dispatch with %s. Using busy wait execution instead.", self._loop_dispatch_class) + log.exception( + "Failed validating loop dispatch with %s. Using busy wait execution instead.", + self._loop_dispatch_class, + ) dispatcher.close() dispatcher = _BusyWaitDispatcher() self._loop_dispatcher = dispatcher @@ -241,7 +251,9 @@ def maybe_start(self): self._loop_lock.release() if should_start: - self._thread = Thread(target=self._run_loop, name="asyncore_cassandra_driver_event_loop") + self._thread = Thread( + target=self._run_loop, name="asyncore_cassandra_driver_event_loop" + ) self._thread.daemon = True self._thread.start() @@ -256,7 +268,9 @@ def _run_loop(self): self._loop_dispatcher.loop(self.timer_resolution) self._timers.service_timeouts() except Exception as exc: - self._maybe_log_debug("Asyncore event loop stopped unexpectedly", exc_info=exc) + self._maybe_log_debug( + "Asyncore event loop stopped unexpectedly", exc_info=exc + ) break self._started = False @@ -291,7 +305,8 @@ def _cleanup(self): if self._thread.is_alive(): log.warning( "Event loop thread could not be joined, so shutdown may not be clean. " - "Please call Cluster.shutdown() to avoid this.") + "Please call Cluster.shutdown() to avoid this." + ) log.debug("Event loop thread was joined") @@ -358,8 +373,9 @@ def __init__(self, *args, **kwargs): init_handler = WaitableTimer( timeout=0, - callback=partial(asyncore.dispatcher.__init__, - self, self._socket, _dispatcher_map) + callback=partial( + asyncore.dispatcher.__init__, self, self._socket, _dispatcher_map + ), ) _global_loop.add_timer(init_handler) init_handler.wait(kwargs["connect_timeout"]) @@ -390,7 +406,7 @@ def close(self): msg += ": %s" % (self.last_error,) self.error_all_requests(ConnectionShutdown(msg)) - #This happens when the connection is shutdown while waiting for the ReadyMessage + # This happens when the connection is shutdown while waiting for the ReadyMessage if not self.connected_event.is_set(): self.last_error = ConnectionShutdown(msg) @@ -417,8 +433,10 @@ def handle_write(self): sent = self.send(next_msg) self._readable = True except socket.error as err: - if (err.args[0] in NONBLOCKING or - err.args[0] in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE)): + if err.args[0] in NONBLOCKING or err.args[0] in ( + ssl.SSL_ERROR_WANT_READ, + ssl.SSL_ERROR_WANT_WRITE, + ): with self.deque_lock: self.deque.appendleft(next_msg) else: @@ -463,7 +481,7 @@ def push(self, data): if len(data) > sabs: chunks = [] for i in range(0, len(data), sabs): - chunks.append(data[i:i + sabs]) + chunks.append(data[i : i + sabs]) else: chunks = [data] @@ -476,4 +494,7 @@ def writable(self): return self._writable def readable(self): - return self._readable or ((self.is_control_connection or self._continuous_paging_sessions) and not (self.is_defunct or self.is_closed)) + return self._readable or ( + (self.is_control_connection or self._continuous_paging_sessions) + and not (self.is_defunct or self.is_closed) + ) diff --git a/cassandra/pool.py b/cassandra/pool.py index 2da657256f..cf9e262665 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -15,6 +15,7 @@ """ Connection pooling and host management. """ + from concurrent.futures import Future from functools import total_ordering import logging @@ -25,10 +26,7 @@ import uuid from threading import Lock, RLock, Condition import weakref -try: - from weakref import WeakSet -except ImportError: - from cassandra.util import WeakSet # NOQA +from weakref import WeakSet from cassandra import AuthenticationFailed from cassandra.connection import ConnectionException, EndPoint, DefaultEndPoint @@ -42,6 +40,7 @@ class NoConnectionsAvailable(Exception): All existing connections to a given host are busy, or there are no open connections. """ + pass @@ -167,13 +166,22 @@ class Host(object): sharding_info = None - def __init__(self, endpoint, conviction_policy_factory, datacenter=None, rack=None, host_id=None): + def __init__( + self, + endpoint, + conviction_policy_factory, + datacenter=None, + rack=None, + host_id=None, + ): if endpoint is None: raise ValueError("endpoint may not be None") if conviction_policy_factory is None: raise ValueError("conviction_policy_factory may not be None") - self.endpoint = endpoint if isinstance(endpoint, EndPoint) else DefaultEndPoint(endpoint) + self.endpoint = ( + endpoint if isinstance(endpoint, EndPoint) else DefaultEndPoint(endpoint) + ) self.conviction_policy = conviction_policy_factory(self) if not host_id: raise ValueError("host_id may not be None") @@ -191,12 +199,12 @@ def address(self): @property def datacenter(self): - """ The datacenter the node is in. """ + """The datacenter the node is in.""" return self._datacenter @property def rack(self): - """ The rack the node is in. """ + """The rack the node is in.""" return self._rack def set_location_info(self, datacenter, rack): @@ -261,7 +269,9 @@ class _ReconnectionHandler(object): _cancelled = False - def __init__(self, scheduler, schedule, callback, *callback_args, **callback_kwargs): + def __init__( + self, scheduler, schedule, callback, *callback_args, **callback_kwargs + ): self.scheduler = scheduler self.schedule = schedule self.callback = callback @@ -295,7 +305,8 @@ def run(self): if next_delay is None: log.warning( "Will not continue to retry reconnection attempts " - "due to an exhausted retry schedule") + "due to an exhausted retry schedule" + ) else: self.scheduler.schedule(next_delay, self.run) else: @@ -343,8 +354,9 @@ def on_exception(self, exc, next_delay): class _HostReconnectionHandler(_ReconnectionHandler): - - def __init__(self, host, connection_factory, is_host_addition, on_add, on_up, *args, **kwargs): + def __init__( + self, host, connection_factory, is_host_addition, on_add, on_up, *args, **kwargs + ): _ReconnectionHandler.__init__(self, *args, **kwargs) self.is_host_addition = is_host_addition self.on_add = on_add @@ -356,7 +368,10 @@ def try_reconnect(self): return self.connection_factory() def on_reconnection(self, connection): - log.info("Successful reconnection to %s, marking node up if it isn't already", self.host) + log.info( + "Successful reconnection to %s, marking node up if it isn't already", + self.host, + ) if self.is_host_addition: self.on_add(self.host) else: @@ -366,8 +381,12 @@ def on_exception(self, exc, next_delay): if isinstance(exc, AuthenticationFailed): return False else: - log.warning("Error attempting to reconnect to %s, scheduling retry in %s seconds: %s", - self.host, next_delay, exc) + log.warning( + "Error attempting to reconnect to %s, scheduling retry in %s seconds: %s", + self.host, + next_delay, + exc, + ) log.debug("Reconnection error details", exc_info=True) return True @@ -423,43 +442,67 @@ def __init__(self, host, host_distance, session): if host_distance == HostDistance.IGNORED: log.debug("Not opening connection to ignored host %s", self.host) return - elif host_distance == HostDistance.REMOTE and not session.cluster.connect_to_remote_hosts: + elif ( + host_distance == HostDistance.REMOTE + and not session.cluster.connect_to_remote_hosts + ): log.debug("Not opening connection to remote host %s", self.host) return log.debug("Initializing connection for host %s", self.host) - first_connection = session.cluster.connection_factory(self.host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released) - log.debug("First connection created to %s for shard_id=%i", self.host, first_connection.features.shard_id) + first_connection = session.cluster.connection_factory( + self.host.endpoint, + on_orphaned_stream_released=self.on_orphaned_stream_released, + ) + log.debug( + "First connection created to %s for shard_id=%i", + self.host, + first_connection.features.shard_id, + ) self._connections[first_connection.features.shard_id] = first_connection self._keyspace = session.keyspace if self._keyspace: first_connection.set_keyspace_blocking(self._keyspace) - if first_connection.features.sharding_info and not self._session.cluster.shard_aware_options.disable: + if ( + first_connection.features.sharding_info + and not self._session.cluster.shard_aware_options.disable + ): self.host.sharding_info = first_connection.features.sharding_info self._open_connections_for_all_shards(first_connection.features.shard_id) self.tablets_routing_v1 = first_connection.features.tablets_routing_v1 log.debug("Finished initializing connection for host %s", self.host) - def _get_connection_for_routing_key(self, routing_key=None, keyspace=None, table=None): + def _get_connection_for_routing_key( + self, routing_key=None, keyspace=None, table=None + ): if self.is_shutdown: raise ConnectionException( - "Pool for %s is shutdown" % (self.host,), self.host) + "Pool for %s is shutdown" % (self.host,), self.host + ) if not self._connections: raise NoConnectionsAvailable() shard_id = None - if not self._session.cluster.shard_aware_options.disable and self.host.sharding_info and routing_key: - t = self._session.cluster.metadata.token_map.token_class.from_key(routing_key) - + if ( + not self._session.cluster.shard_aware_options.disable + and self.host.sharding_info + and routing_key + ): + t = self._session.cluster.metadata.token_map.token_class.from_key( + routing_key + ) + shard_id = None if self.tablets_routing_v1 and table is not None: if keyspace is None: keyspace = self._keyspace - tablet = self._session.cluster.metadata._tablets.get_tablet_for_key(keyspace, table, t) + tablet = self._session.cluster.metadata._tablets.get_tablet_for_key( + keyspace, table, t + ) if tablet is not None: for replica in tablet.replicas: @@ -480,20 +523,22 @@ def _get_connection_for_routing_key(self, routing_key=None, keyspace=None, table "Using connection to shard_id=%i on host %s for routing_key=%s", shard_id, self.host, - routing_key + routing_key, ) if conn.orphaned_threshold_reached and shard_id not in self._connecting: # The connection has met its orphaned stream ID limit # and needs to be replaced. Start opening a connection # to the same shard and replace when it is opened. self._connecting.add(shard_id) - self._session.submit(self._open_connection_to_missing_shard, shard_id) + self._session.submit( + self._open_connection_to_missing_shard, shard_id + ) log.debug( "Connection to shard_id=%i reached orphaned stream limit, replacing on host %s (%s/%i)", shard_id, self.host, len(self._connections), - self.host.sharding_info.shards_count + self.host.sharding_info.shards_count, ) elif shard_id not in self._connecting: # rate controlled optimistic attempt to connect to a missing shard @@ -504,12 +549,14 @@ def _get_connection_for_routing_key(self, routing_key=None, keyspace=None, table shard_id, self.host, len(self._connections), - self.host.sharding_info.shards_count + self.host.sharding_info.shards_count, ) if conn and not conn.is_closed: return conn - active_connections = [conn for conn in list(self._connections.values()) if not conn.is_closed] + active_connections = [ + conn for conn in list(self._connections.values()) if not conn.is_closed + ] if active_connections: return random.choice(active_connections) return random.choice(list(self._connections.values())) @@ -522,9 +569,13 @@ def borrow_connection(self, timeout, routing_key=None, keyspace=None, table=None while True: if conn.is_closed: # The connection might have been closed in the meantime - if so, try again - conn = self._get_connection_for_routing_key(routing_key, keyspace, table) + conn = self._get_connection_for_routing_key( + routing_key, keyspace, table + ) with conn.lock: - if (not conn.is_closed or last_retry) and conn.in_flight < conn.max_request_id: + if ( + not conn.is_closed or last_retry + ) and conn.in_flight < conn.max_request_id: # On last retry we ignore connection status, since it is better to return closed connection than # raise Exception conn.in_flight += 1 @@ -558,8 +609,12 @@ def return_connection(self, connection, stream_was_orphaned=False): is_down = False if not connection.signaled_error: - log.debug("Defunct or closed connection (%s) returned to pool, potentially " - "marking host %s as down", id(connection), self.host) + log.debug( + "Defunct or closed connection (%s) returned to pool, potentially " + "marking host %s as down", + id(connection), + self.host, + ) is_down = self.host.signal_connection_failure(connection.last_error) connection.signaled_error = True @@ -581,7 +636,9 @@ def return_connection(self, connection, stream_was_orphaned=False): self._session.submit(self._replace, connection) elif connection in self._trash: with connection.lock: - no_pending_requests = connection.in_flight <= len(connection.orphaned_request_ids) + no_pending_requests = connection.in_flight <= len( + connection.orphaned_request_ids + ) if no_pending_requests: with self._lock: close_connection = False @@ -589,7 +646,11 @@ def return_connection(self, connection, stream_was_orphaned=False): self._trash.remove(connection) close_connection = True if close_connection: - log.debug("Closing trashed connection (%s) to %s", id(connection), self.host) + log.debug( + "Closing trashed connection (%s) to %s", + id(connection), + self.host, + ) connection.close() def on_orphaned_stream_released(self): @@ -609,12 +670,20 @@ def _replace(self, connection): try: if connection.features.shard_id in self._connections: del self._connections[connection.features.shard_id] - if self.host.sharding_info and not self._session.cluster.shard_aware_options.disable: + if ( + self.host.sharding_info + and not self._session.cluster.shard_aware_options.disable + ): self._connecting.add(connection.features.shard_id) - self._session.submit(self._open_connection_to_missing_shard, connection.features.shard_id) + self._session.submit( + self._open_connection_to_missing_shard, + connection.features.shard_id, + ) else: - connection = self._session.cluster.connection_factory(self.host.endpoint, - on_orphaned_stream_released=self.on_orphaned_stream_released) + connection = self._session.cluster.connection_factory( + self.host.endpoint, + on_orphaned_stream_released=self.on_orphaned_stream_released, + ) if self._keyspace: connection.set_keyspace_blocking(self._keyspace) self._connections[connection.features.shard_id] = connection @@ -652,7 +721,9 @@ def shutdown(self): connection.close() for connection in pending_connections_to_close: - log.debug("Closing pending connection (%s) to %s", id(connection), self.host) + log.debug( + "Closing pending connection (%s) to %s", id(connection), self.host + ) connection.close() self._close_excess_connections() @@ -679,16 +750,26 @@ def _close_excess_connections(self): c.close() def disable_advanced_shard_aware(self, secs): - log.warning("disabling advanced_shard_aware for %i seconds, could be that this client is behind NAT?", secs) - self.advanced_shardaware_block_until = max(time.time() + secs, self.advanced_shardaware_block_until) + log.warning( + "disabling advanced_shard_aware for %i seconds, could be that this client is behind NAT?", + secs, + ) + self.advanced_shardaware_block_until = max( + time.time() + secs, self.advanced_shardaware_block_until + ) def _get_shard_aware_endpoint(self): - if (self.advanced_shardaware_block_until and self.advanced_shardaware_block_until < time.time()) or \ - self._session.cluster.shard_aware_options.disable_shardaware_port: + if ( + self.advanced_shardaware_block_until + and self.advanced_shardaware_block_until < time.time() + ) or self._session.cluster.shard_aware_options.disable_shardaware_port: return None endpoint = None - if self._session.cluster.ssl_options and self.host.sharding_info.shard_aware_port_ssl: + if ( + self._session.cluster.ssl_options + and self.host.sharding_info.shard_aware_port_ssl + ): endpoint = copy.copy(self.host.endpoint) endpoint._port = self.host.sharding_info.shard_aware_port_ssl elif self.host.sharding_info.shard_aware_port: @@ -722,23 +803,41 @@ def _open_connection_to_missing_shard(self, shard_id): log.debug("shard_aware_endpoint=%r", shard_aware_endpoint) if shard_aware_endpoint: try: - conn = self._session.cluster.connection_factory(shard_aware_endpoint, host_conn=self, on_orphaned_stream_released=self.on_orphaned_stream_released, - shard_id=shard_id, - total_shards=self.host.sharding_info.shards_count) + conn = self._session.cluster.connection_factory( + shard_aware_endpoint, + host_conn=self, + on_orphaned_stream_released=self.on_orphaned_stream_released, + shard_id=shard_id, + total_shards=self.host.sharding_info.shards_count, + ) conn.original_endpoint = self.host.endpoint except Exception as exc: - log.error("Failed to open connection to %s, on shard_id=%i: %s", self.host, shard_id, exc) + log.error( + "Failed to open connection to %s, on shard_id=%i: %s", + self.host, + shard_id, + exc, + ) raise else: - conn = self._session.cluster.connection_factory(self.host.endpoint, host_conn=self, on_orphaned_stream_released=self.on_orphaned_stream_released) + conn = self._session.cluster.connection_factory( + self.host.endpoint, + host_conn=self, + on_orphaned_stream_released=self.on_orphaned_stream_released, + ) log.debug( "Received a connection %s for shard_id=%i on host %s", id(conn), conn.features.shard_id if conn.features.shard_id is not None else -1, - self.host) + self.host, + ) if self.is_shutdown: - log.debug("Pool for host %s is in shutdown, closing the new connection (%s)", self.host, id(conn)) + log.debug( + "Pool for host %s is in shutdown, closing the new connection (%s)", + self.host, + id(conn), + ) conn.close() return @@ -753,7 +852,7 @@ def _open_connection_to_missing_shard(self, shard_id): "New connection (%s) created to shard_id=%i on host %s", id(conn), conn.features.shard_id, - self.host + self.host, ) old_conn = None with self._lock: @@ -767,7 +866,7 @@ def _open_connection_to_missing_shard(self, shard_id): id(old_conn), id(conn), conn.features.shard_id, - self.host + self.host, ) if self._keyspace: conn.set_keyspace_blocking(self._keyspace) @@ -784,7 +883,7 @@ def _open_connection_to_missing_shard(self, shard_id): "Immediately closing the old connection (%s) for shard %i on host %s", id(old_conn), old_conn.features.shard_id, - self.host + self.host, ) old_conn.close() else: @@ -807,20 +906,23 @@ def _open_connection_to_missing_shard(self, shard_id): len(self._connections), self.host.sharding_info.shards_count, self.host, - num_missing_or_needing_replacement + num_missing_or_needing_replacement, ) if num_missing_or_needing_replacement == 0: log.debug( "All shards of host %s have at least one connection, closing %i excess connections", self.host, - len(self._excess_connections) + len(self._excess_connections), ) self._close_excess_connections() - elif self.host.sharding_info.shards_count == len(self._connections) and self.num_missing_or_needing_replacement == 0: + elif ( + self.host.sharding_info.shards_count == len(self._connections) + and self.num_missing_or_needing_replacement == 0 + ): log.debug( "All shards are already covered, closing newly opened excess connection %s for host %s", id(self), - self.host + self.host, ) conn.close() else: @@ -830,7 +932,7 @@ def _open_connection_to_missing_shard(self, shard_id): id(conn), self._excess_connection_limit, self.host, - len(self._excess_connections) + len(self._excess_connections), ) self._close_excess_connections() @@ -838,7 +940,7 @@ def _open_connection_to_missing_shard(self, shard_id): "Putting a connection %s to shard %i to the excess pool of host %s", id(conn), conn.features.shard_id, - self.host + self.host, ) close_connection = False with self._lock: @@ -861,7 +963,9 @@ def _open_connections_for_all_shards(self, skip_shard_id=None): for shard_id in range(self.host.sharding_info.shards_count): if skip_shard_id is not None and skip_shard_id == shard_id: continue - future = self._session.submit(self._open_connection_to_missing_shard, shard_id) + future = self._session.submit( + self._open_connection_to_missing_shard, shard_id + ) if isinstance(future, Future): self._connecting.add(shard_id) self._shard_connections_futures.append(future) @@ -910,21 +1014,36 @@ def get_connections(self): def get_state(self): in_flights = [c.in_flight for c in list(self._connections.values())] - orphan_requests = [c.orphaned_request_ids for c in list(self._connections.values())] - return {'shutdown': self.is_shutdown, 'open_count': self.open_count, \ - 'in_flights': in_flights, 'orphan_requests': orphan_requests} + orphan_requests = [ + c.orphaned_request_ids for c in list(self._connections.values()) + ] + return { + "shutdown": self.is_shutdown, + "open_count": self.open_count, + "in_flights": in_flights, + "orphan_requests": orphan_requests, + } @property def num_missing_or_needing_replacement(self): - return self.host.sharding_info.shards_count \ - - sum(1 for c in list(self._connections.values()) if not c.orphaned_threshold_reached) + return self.host.sharding_info.shards_count - sum( + 1 + for c in list(self._connections.values()) + if not c.orphaned_threshold_reached + ) @property def open_count(self): - return sum([1 if c and not (c.is_closed or c.is_defunct) else 0 for c in list(self._connections.values())]) + return sum( + [ + 1 if c and not (c.is_closed or c.is_defunct) else 0 + for c in list(self._connections.values()) + ] + ) @property def _excess_connection_limit(self): - return self.host.sharding_info.shards_count * self.max_excess_connections_per_shard_multiplier - - + return ( + self.host.sharding_info.shards_count + * self.max_excess_connections_per_shard_multiplier + ) diff --git a/cassandra/util.py b/cassandra/util.py index 593c264033..4f5e9411b8 100644 --- a/cassandra/util.py +++ b/cassandra/util.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from _weakref import ref import calendar from collections import OrderedDict from collections.abc import Mapping @@ -40,14 +39,16 @@ from cassandra import DriverException DATETIME_EPOC = datetime.datetime(1970, 1, 1).replace(tzinfo=None) -UTC_DATETIME_EPOC = datetime.datetime.fromtimestamp(0, tz=datetime.timezone.utc).replace(tzinfo=None) +UTC_DATETIME_EPOC = datetime.datetime.fromtimestamp( + 0, tz=datetime.timezone.utc +).replace(tzinfo=None) -_nan = float('nan') +_nan = float("nan") log = logging.getLogger(__name__) -assert sys.byteorder in ('little', 'big') -is_little_endian = sys.byteorder == 'little' +assert sys.byteorder in ("little", "big") +is_little_endian = sys.byteorder == "little" def datetime_from_timestamp(timestamp): @@ -121,7 +122,9 @@ def min_uuid_from_time(timestamp): See :func:`uuid_from_time` for argument and return types. """ - return uuid_from_time(timestamp, 0x808080808080, 0x80) # Cassandra does byte-wise comparison; fill with min signed bytes (0x80 = -128) + return uuid_from_time( + timestamp, 0x808080808080, 0x80 + ) # Cassandra does byte-wise comparison; fill with min signed bytes (0x80 = -128) def max_uuid_from_time(timestamp): @@ -130,7 +133,9 @@ def max_uuid_from_time(timestamp): See :func:`uuid_from_time` for argument and return types. """ - return uuid_from_time(timestamp, 0x7f7f7f7f7f7f, 0x3f7f) # Max signed bytes (0x7f = 127) + return uuid_from_time( + timestamp, 0x7F7F7F7F7F7F, 0x3F7F + ) # Max signed bytes (0x7f = 127) def uuid_from_time(time_arg, node=None, clock_seq=None): @@ -156,7 +161,7 @@ def uuid_from_time(time_arg, node=None, clock_seq=None): :rtype: :class:`uuid.UUID` """ - if hasattr(time_arg, 'utctimetuple'): + if hasattr(time_arg, "utctimetuple"): seconds = int(calendar.timegm(time_arg.utctimetuple())) microseconds = (seconds * 1e6) + time_arg.time().microsecond else: @@ -164,31 +169,41 @@ def uuid_from_time(time_arg, node=None, clock_seq=None): # 0x01b21dd213814000 is the number of 100-ns intervals between the # UUID epoch 1582-10-15 00:00:00 and the Unix epoch 1970-01-01 00:00:00. - intervals = int(microseconds * 10) + 0x01b21dd213814000 + intervals = int(microseconds * 10) + 0x01B21DD213814000 - time_low = intervals & 0xffffffff - time_mid = (intervals >> 32) & 0xffff - time_hi_version = (intervals >> 48) & 0x0fff + time_low = intervals & 0xFFFFFFFF + time_mid = (intervals >> 32) & 0xFFFF + time_hi_version = (intervals >> 48) & 0x0FFF if clock_seq is None: clock_seq = random.getrandbits(14) else: - if clock_seq > 0x3fff: - raise ValueError('clock_seq is out of range (need a 14-bit value)') + if clock_seq > 0x3FFF: + raise ValueError("clock_seq is out of range (need a 14-bit value)") - clock_seq_low = clock_seq & 0xff - clock_seq_hi_variant = 0x80 | ((clock_seq >> 8) & 0x3f) + clock_seq_low = clock_seq & 0xFF + clock_seq_hi_variant = 0x80 | ((clock_seq >> 8) & 0x3F) if node is None: node = random.getrandbits(48) - return uuid.UUID(fields=(time_low, time_mid, time_hi_version, - clock_seq_hi_variant, clock_seq_low, node), version=1) + return uuid.UUID( + fields=( + time_low, + time_mid, + time_hi_version, + clock_seq_hi_variant, + clock_seq_low, + node, + ), + version=1, + ) + -LOWEST_TIME_UUID = uuid.UUID('00000000-0000-1000-8080-808080808080') +LOWEST_TIME_UUID = uuid.UUID("00000000-0000-1000-8080-808080808080") """ The lowest possible TimeUUID, as sorted by Cassandra. """ -HIGHEST_TIME_UUID = uuid.UUID('ffffffff-ffff-1fff-bf7f-7f7f7f7f7f7f') +HIGHEST_TIME_UUID = uuid.UUID("ffffffff-ffff-1fff-bf7f-7f7f7f7f7f7f") """ The highest possible TimeUUID, as sorted by Cassandra. """ @@ -199,12 +214,14 @@ def _addrinfo_or_none(contact_point, port): PYTHON-895. """ try: - value = socket.getaddrinfo(contact_point, port, - socket.AF_UNSPEC, socket.SOCK_STREAM) + value = socket.getaddrinfo( + contact_point, port, socket.AF_UNSPEC, socket.SOCK_STREAM + ) return value except socket.gaierror: - log.debug('Could not resolve hostname "{}" ' - 'with port {}'.format(contact_point, port)) + log.debug( + 'Could not resolve hostname "{}" with port {}'.format(contact_point, port) + ) return None @@ -223,230 +240,23 @@ def _addrinfo_to_ip_strings(addrinfo): def _resolve_contact_points_to_string_map(contact_points): return OrderedDict( - ('{cp}:{port}'.format(cp=cp, port=port), _addrinfo_to_ip_strings(_addrinfo_or_none(cp, port))) + ( + "{cp}:{port}".format(cp=cp, port=port), + _addrinfo_to_ip_strings(_addrinfo_or_none(cp, port)), + ) for cp, port in contact_points ) -class _IterationGuard(object): - # This context manager registers itself in the current iterators of the - # weak container, such as to delay all removals until the context manager - # exits. - # This technique should be relatively thread-safe (since sets are). - - def __init__(self, weakcontainer): - # Don't create cycles - self.weakcontainer = ref(weakcontainer) - - def __enter__(self): - w = self.weakcontainer() - if w is not None: - w._iterating.add(self) - return self - - def __exit__(self, e, t, b): - w = self.weakcontainer() - if w is not None: - s = w._iterating - s.remove(self) - if not s: - w._commit_removals() - - -class WeakSet(object): - def __init__(self, data=None): - self.data = set() - - def _remove(item, selfref=ref(self)): - self = selfref() - if self is not None: - if self._iterating: - self._pending_removals.append(item) - else: - self.data.discard(item) - - self._remove = _remove - # A list of keys to be removed - self._pending_removals = [] - self._iterating = set() - if data is not None: - self.update(data) - - def _commit_removals(self): - l = self._pending_removals - discard = self.data.discard - while l: - discard(l.pop()) - - def __iter__(self): - with _IterationGuard(self): - for itemref in self.data: - item = itemref() - if item is not None: - yield item - - def __len__(self): - return sum(x() is not None for x in self.data) - - def __contains__(self, item): - return ref(item) in self.data - - def __reduce__(self): - return (self.__class__, (list(self),), - getattr(self, '__dict__', None)) - - __hash__ = None - - def add(self, item): - if self._pending_removals: - self._commit_removals() - self.data.add(ref(item, self._remove)) - - def clear(self): - if self._pending_removals: - self._commit_removals() - self.data.clear() - - def copy(self): - return self.__class__(self) - - def pop(self): - if self._pending_removals: - self._commit_removals() - while True: - try: - itemref = self.data.pop() - except KeyError: - raise KeyError('pop from empty WeakSet') - item = itemref() - if item is not None: - return item - - def remove(self, item): - if self._pending_removals: - self._commit_removals() - self.data.remove(ref(item)) - - def discard(self, item): - if self._pending_removals: - self._commit_removals() - self.data.discard(ref(item)) - - def update(self, other): - if self._pending_removals: - self._commit_removals() - if isinstance(other, self.__class__): - self.data.update(other.data) - else: - for element in other: - self.add(element) - - def __ior__(self, other): - self.update(other) - return self - - # Helper functions for simple delegating methods. - def _apply(self, other, method): - if not isinstance(other, self.__class__): - other = self.__class__(other) - newdata = method(other.data) - newset = self.__class__() - newset.data = newdata - return newset - - def difference(self, other): - return self._apply(other, self.data.difference) - __sub__ = difference - - def difference_update(self, other): - if self._pending_removals: - self._commit_removals() - if self is other: - self.data.clear() - else: - self.data.difference_update(ref(item) for item in other) - - def __isub__(self, other): - if self._pending_removals: - self._commit_removals() - if self is other: - self.data.clear() - else: - self.data.difference_update(ref(item) for item in other) - return self - - def intersection(self, other): - return self._apply(other, self.data.intersection) - __and__ = intersection - - def intersection_update(self, other): - if self._pending_removals: - self._commit_removals() - self.data.intersection_update(ref(item) for item in other) - - def __iand__(self, other): - if self._pending_removals: - self._commit_removals() - self.data.intersection_update(ref(item) for item in other) - return self - - def issubset(self, other): - return self.data.issubset(ref(item) for item in other) - __lt__ = issubset - - def __le__(self, other): - return self.data <= set(ref(item) for item in other) - - def issuperset(self, other): - return self.data.issuperset(ref(item) for item in other) - __gt__ = issuperset - - def __ge__(self, other): - return self.data >= set(ref(item) for item in other) - - def __eq__(self, other): - if not isinstance(other, self.__class__): - return NotImplemented - return self.data == set(ref(item) for item in other) - - def symmetric_difference(self, other): - return self._apply(other, self.data.symmetric_difference) - __xor__ = symmetric_difference - - def symmetric_difference_update(self, other): - if self._pending_removals: - self._commit_removals() - if self is other: - self.data.clear() - else: - self.data.symmetric_difference_update(ref(item) for item in other) - - def __ixor__(self, other): - if self._pending_removals: - self._commit_removals() - if self is other: - self.data.clear() - else: - self.data.symmetric_difference_update(ref(item) for item in other) - return self - - def union(self, other): - return self._apply(other, self.data.union) - __or__ = union - - def isdisjoint(self, other): - return len(self.intersection(other)) == 0 - - class SortedSet(object): - ''' + """ A sorted set based on sorted list A sorted set implementation is used in this case because it does not require its elements to be immutable/hashable. #Not implemented: update functions, inplace operators - ''' + """ def __init__(self, iterable=()): self._items = [] @@ -465,9 +275,7 @@ def __reversed__(self): return reversed(self._items) def __repr__(self): - return '%s(%r)' % ( - self.__class__.__name__, - self._items) + return "%s(%r)" % (self.__class__.__name__, self._items) def __reduce__(self): return self.__class__, (self._items,) @@ -477,7 +285,9 @@ def __eq__(self, other): return self._items == other._items else: try: - return len(other) == len(self._items) and all(item in self for item in other) + return len(other) == len(self._items) and all( + item in self for item in other + ) except TypeError: return NotImplemented @@ -486,7 +296,9 @@ def __ne__(self, other): return self._items != other._items else: try: - return len(other) != len(self._items) or any(item not in self for item in other) + return len(other) != len(self._items) or any( + item not in self for item in other + ) except TypeError: return NotImplemented @@ -504,6 +316,7 @@ def __gt__(self, other): def __and__(self, other): return self._intersect(other) + __rand__ = __and__ def __iand__(self, other): @@ -513,6 +326,7 @@ def __iand__(self, other): def __or__(self, other): return self.union(other) + __ror__ = __or__ def __ior__(self, other): @@ -533,6 +347,7 @@ def __isub__(self, other): def __xor__(self, other): return self.symmetric_difference(other) + __rxor__ = __xor__ def __ixor__(self, other): @@ -590,7 +405,7 @@ def remove(self, item): if self._items[i] == item: self._items.pop(i) return - raise KeyError('%r' % item) + raise KeyError("%r" % item) def union(self, *others): union = sortedset() @@ -645,8 +460,10 @@ def _find_insertion(self, x): try: while lo < hi: mid = (lo + hi) // 2 - if a[mid] < x: lo = mid + 1 - else: hi = mid + if a[mid] < x: + lo = mid + 1 + else: + hi = mid except TypeError: # could not compare a[mid] with x # start scanning to find insertion point while swallowing type errors @@ -654,18 +471,21 @@ def _find_insertion(self, x): compared_one = False # flag is used to determine whether uncomparables are grouped at the front or back while lo < hi: try: - if a[lo] == x or a[lo] >= x: break + if a[lo] == x or a[lo] >= x: + break compared_one = True except TypeError: - if compared_one: break + if compared_one: + break lo += 1 return lo + sortedset = SortedSet # backwards-compatibility class OrderedMap(Mapping): - ''' + """ An ordered map that accepts non-hashable types for keys. It also maintains the insertion order of items, behaving as OrderedDict in that regard. These maps are constructed and read just as normal mapping types, except that they may @@ -689,17 +509,17 @@ class OrderedMap(Mapping): This class derives from the (immutable) Mapping API. Objects in these maps are not intended be modified. - ''' + """ def __init__(self, *args, **kwargs): if len(args) > 1: - raise TypeError('expected at most 1 arguments, got %d' % len(args)) + raise TypeError("expected at most 1 arguments, got %d" % len(args)) self._items = [] self._index = {} if args: e = args[0] - if callable(getattr(e, 'keys', None)): + if callable(getattr(e, "keys", None)): for k in e.keys(): self._insert(k, e[k]) else: @@ -731,7 +551,9 @@ def __delitem__(self, key): # not efficient -- for convenience only try: index = self._index.pop(self._serialize_key(key)) - self._index = dict((k, i if i < index else i - 1) for k, i in self._index.items()) + self._index = dict( + (k, i if i < index else i - 1) for k, i in self._index.items() + ) self._items.pop(index) except KeyError: raise KeyError(str(key)) @@ -748,7 +570,9 @@ def __eq__(self, other): return self._items == other._items try: d = dict(other) - return len(d) == len(self._items) and all(i[1] == d[i[0]] for i in self._items) + return len(d) == len(self._items) and all( + i[1] == d[i[0]] for i in self._items + ) except KeyError: return False except TypeError: @@ -756,12 +580,13 @@ def __eq__(self, other): return NotImplemented def __repr__(self): - return '%s([%s])' % ( + return "%s([%s])" % ( self.__class__.__name__, - ', '.join("(%r, %r)" % (k, v) for k, v in self._items)) + ", ".join("(%r, %r)" % (k, v) for k, v in self._items), + ) def __str__(self): - return '{%s}' % ', '.join("%r: %r" % (k, v) for k, v in self._items) + return "{%s}" % ", ".join("%r: %r" % (k, v) for k, v in self._items) def popitem(self): try: @@ -776,7 +601,6 @@ def _serialize_key(self, key): class OrderedMapSerializedKey(OrderedMap): - def __init__(self, cass_type, protocol_version): super(OrderedMapSerializedKey, self).__init__() self.cass_key_type = cass_type @@ -792,11 +616,11 @@ def _serialize_key(self, key): @total_ordering class Time(object): - ''' + """ Idealized time, independent of day. Up to nanosecond resolution - ''' + """ MICRO = 1000 MILLI = 1000 * MICRO @@ -822,7 +646,9 @@ def __init__(self, value): elif isinstance(value, str): self._from_timestring(value) else: - raise TypeError('Time arguments must be a whole number, datetime.time, or string') + raise TypeError( + "Time arguments must be a whole number, datetime.time, or string" + ) @property def hour(self): @@ -858,21 +684,29 @@ def time(self): """ Return a built-in datetime.time (nanosecond precision truncated to micros). """ - return datetime.time(hour=self.hour, minute=self.minute, second=self.second, - microsecond=self.nanosecond // Time.MICRO) + return datetime.time( + hour=self.hour, + minute=self.minute, + second=self.second, + microsecond=self.nanosecond // Time.MICRO, + ) def _from_timestamp(self, t): if t >= Time.DAY: - raise ValueError("value must be less than number of nanoseconds in a day (%d)" % Time.DAY) + raise ValueError( + "value must be less than number of nanoseconds in a day (%d)" % Time.DAY + ) self.nanosecond_time = t def _from_timestring(self, s): try: - parts = s.split('.') + parts = s.split(".") base_time = time.strptime(parts[0], "%H:%M:%S") - self.nanosecond_time = (base_time.tm_hour * Time.HOUR + - base_time.tm_min * Time.MINUTE + - base_time.tm_sec * Time.SECOND) + self.nanosecond_time = ( + base_time.tm_hour * Time.HOUR + + base_time.tm_min * Time.MINUTE + + base_time.tm_sec * Time.SECOND + ) if len(parts) > 1: # right pad to 9 digits @@ -883,10 +717,12 @@ def _from_timestring(self, s): raise ValueError("can't interpret %r as a time" % (s,)) def _from_time(self, t): - self.nanosecond_time = (t.hour * Time.HOUR + - t.minute * Time.MINUTE + - t.second * Time.SECOND + - t.microsecond * Time.MICRO) + self.nanosecond_time = ( + t.hour * Time.HOUR + + t.minute * Time.MINUTE + + t.second * Time.SECOND + + t.microsecond * Time.MICRO + ) def __hash__(self): return self.nanosecond_time @@ -898,9 +734,16 @@ def __eq__(self, other): if isinstance(other, int): return self.nanosecond_time == other - return self.nanosecond_time % Time.MICRO == 0 and \ - datetime.time(hour=self.hour, minute=self.minute, second=self.second, - microsecond=self.nanosecond // Time.MICRO) == other + return ( + self.nanosecond_time % Time.MICRO == 0 + and datetime.time( + hour=self.hour, + minute=self.minute, + second=self.second, + microsecond=self.nanosecond // Time.MICRO, + ) + == other + ) def __ne__(self, other): return not self.__eq__(other) @@ -914,19 +757,23 @@ def __repr__(self): return "Time(%s)" % self.nanosecond_time def __str__(self): - return "%02d:%02d:%02d.%09d" % (self.hour, self.minute, - self.second, self.nanosecond) + return "%02d:%02d:%02d.%09d" % ( + self.hour, + self.minute, + self.second, + self.nanosecond, + ) @total_ordering class Date(object): - ''' + """ Idealized date: year, month, day Offers wider year range than datetime.date. For Dates that cannot be represented as a datetime.date (because datetime.MINYEAR, datetime.MAXYEAR), this type falls back to printing days_from_epoch offset. - ''' + """ MINUTE = 60 HOUR = 60 * MINUTE @@ -951,7 +798,9 @@ def __init__(self, value): elif isinstance(value, str): self._from_datestring(value) else: - raise TypeError('Date arguments must be a whole number, datetime.date, or string') + raise TypeError( + "Date arguments must be a whole number, datetime.date, or string" + ) @property def seconds(self): @@ -976,7 +825,7 @@ def _from_timetuple(self, t): self.days_from_epoch = calendar.timegm(t) // Date.DAY def _from_datestring(self, s): - if s[0] == '+': + if s[0] == "+": s = s[1:] dt = datetime.datetime.strptime(s, self.date_format) self._from_timetuple(dt.timetuple()) @@ -1024,12 +873,14 @@ def __str__(self): def _positional_rename_invalid_identifiers(field_names): names_out = list(field_names) for index, name in enumerate(field_names): - if (not all(c.isalnum() or c == '_' for c in name) + if ( + not all(c.isalnum() or c == "_" for c in name) or keyword.iskeyword(name) or not name or name[0].isdigit() - or name.startswith('_')): - names_out[index] = 'field_%d_' % index + or name.startswith("_") + ): + names_out[index] = "field_%d_" % index return names_out @@ -1101,10 +952,14 @@ def from_wkt(s): except ValueError: raise ValueError("Invalid WKT geometry: '{0}'".format(s)) - if geom['type'] != 'Point': - raise ValueError("Invalid WKT geometry type. Expected 'Point', got '{0}': '{1}'".format(geom['type'], s)) + if geom["type"] != "Point": + raise ValueError( + "Invalid WKT geometry type. Expected 'Point', got '{0}': '{1}'".format( + geom["type"], s + ) + ) - coords = geom['coordinates'] + coords = geom["coordinates"] if len(coords) < 2: x = y = _nan else: @@ -1123,6 +978,7 @@ class LineString(object): """ Tuple of (x, y) coordinates in the linestring """ + def __init__(self, coords=tuple()): """ 'coords`: a sequence of (x, y) coordinates of points in the linestring @@ -1141,7 +997,7 @@ def __str__(self): """ if not self.coords: return "LINESTRING EMPTY" - return "LINESTRING (%s)" % ', '.join("%r %r" % (x, y) for x, y in self.coords) + return "LINESTRING (%s)" % ", ".join("%r %r" % (x, y) for x, y in self.coords) def __repr__(self): return "%s(%r)" % (self.__class__.__name__, self.coords) @@ -1159,12 +1015,16 @@ def from_wkt(s): except ValueError: raise ValueError("Invalid WKT geometry: '{0}'".format(s)) - if geom['type'] != 'LineString': - raise ValueError("Invalid WKT geometry type. Expected 'LineString', got '{0}': '{1}'".format(geom['type'], s)) + if geom["type"] != "LineString": + raise ValueError( + "Invalid WKT geometry type. Expected 'LineString', got '{0}': '{1}'".format( + geom["type"], s + ) + ) - geom['coordinates'] = list_contents_to_tuple(geom['coordinates']) + geom["coordinates"] = list_contents_to_tuple(geom["coordinates"]) - return LineString(coords=geom['coordinates']) + return LineString(coords=geom["coordinates"]) class _LinearRing(object): @@ -1182,7 +1042,7 @@ def __hash__(self): def __str__(self): if not self.coords: return "LINEARRING EMPTY" - return "LINEARRING (%s)" % ', '.join("%r %r" % (x, y) for x, y in self.coords) + return "LINEARRING (%s)" % ", ".join("%r %r" % (x, y) for x, y in self.coords) def __repr__(self): return "%s(%r)" % (self.__class__.__name__, self.coords) @@ -1209,10 +1069,16 @@ def __init__(self, exterior=tuple(), interiors=None): `interiors`: None, or a sequence of sequences or (x, y) coordinates of points describing interior linear rings """ self.exterior = _LinearRing(exterior) - self.interiors = tuple(_LinearRing(e) for e in interiors) if interiors else tuple() + self.interiors = ( + tuple(_LinearRing(e) for e in interiors) if interiors else tuple() + ) def __eq__(self, other): - return isinstance(other, Polygon) and self.exterior == other.exterior and self.interiors == other.interiors + return ( + isinstance(other, Polygon) + and self.exterior == other.exterior + and self.interiors == other.interiors + ) def __hash__(self): return hash((self.exterior, self.interiors)) @@ -1224,11 +1090,17 @@ def __str__(self): if not self.exterior.coords: return "POLYGON EMPTY" rings = [ring.coords for ring in chain((self.exterior,), self.interiors)] - rings = ["(%s)" % ', '.join("%r %r" % (x, y) for x, y in ring) for ring in rings] - return "POLYGON (%s)" % ', '.join(rings) + rings = [ + "(%s)" % ", ".join("%r %r" % (x, y) for x, y in ring) for ring in rings + ] + return "POLYGON (%s)" % ", ".join(rings) def __repr__(self): - return "%s(%r, %r)" % (self.__class__.__name__, self.exterior.coords, [ring.coords for ring in self.interiors]) + return "%s(%r, %r)" % ( + self.__class__.__name__, + self.exterior.coords, + [ring.coords for ring in self.interiors], + ) @staticmethod def from_wkt(s): @@ -1243,17 +1115,24 @@ def from_wkt(s): except ValueError: raise ValueError("Invalid WKT geometry: '{0}'".format(s)) - if geom['type'] != 'Polygon': - raise ValueError("Invalid WKT geometry type. Expected 'Polygon', got '{0}': '{1}'".format(geom['type'], s)) + if geom["type"] != "Polygon": + raise ValueError( + "Invalid WKT geometry type. Expected 'Polygon', got '{0}': '{1}'".format( + geom["type"], s + ) + ) - coords = geom['coordinates'] + coords = geom["coordinates"] exterior = coords[0] if len(coords) > 0 else tuple() interiors = coords[1:] if len(coords) > 1 else None return Polygon(exterior=exterior, interiors=interiors) -_distance_wkt_pattern = re.compile("distance *\\( *\\( *([\\d\\.-]+) *([\\d+\\.-]+) *\\) *([\\d+\\.-]+) *\\) *$", re.IGNORECASE) +_distance_wkt_pattern = re.compile( + "distance *\\( *\\( *([\\d\\.-]+) *([\\d+\\.-]+) *\\) *([\\d+\\.-]+) *\\) *$", + re.IGNORECASE, +) class Distance(object): @@ -1282,7 +1161,12 @@ def __init__(self, x=_nan, y=_nan, radius=_nan): self.radius = radius def __eq__(self, other): - return isinstance(other, Distance) and self.x == other.x and self.y == other.y and self.radius == other.radius + return ( + isinstance(other, Distance) + and self.x == other.x + and self.y == other.y + and self.radius == other.radius + ) def __hash__(self): return hash((self.x, self.y, self.radius)) @@ -1329,18 +1213,25 @@ def __init__(self, months=0, days=0, nanoseconds=0): self.nanoseconds = nanoseconds def __eq__(self, other): - return isinstance(other, self.__class__) and self.months == other.months and self.days == other.days and self.nanoseconds == other.nanoseconds + return ( + isinstance(other, self.__class__) + and self.months == other.months + and self.days == other.days + and self.nanoseconds == other.nanoseconds + ) def __repr__(self): - return "Duration({0}, {1}, {2})".format(self.months, self.days, self.nanoseconds) + return "Duration({0}, {1}, {2})".format( + self.months, self.days, self.nanoseconds + ) def __str__(self): has_negative_values = self.months < 0 or self.days < 0 or self.nanoseconds < 0 - return '%s%dmo%dd%dns' % ( - '-' if has_negative_values else '', + return "%s%dmo%dd%dns" % ( + "-" if has_negative_values else "", abs(self.months), abs(self.days), - abs(self.nanoseconds) + abs(self.nanoseconds), ) @@ -1348,36 +1239,36 @@ class DateRangePrecision(object): """ An "enum" representing the valid values for :attr:`DateRange.precision`. """ - YEAR = 'YEAR' + + YEAR = "YEAR" """ """ - MONTH = 'MONTH' + MONTH = "MONTH" """ """ - DAY = 'DAY' + DAY = "DAY" """ """ - HOUR = 'HOUR' + HOUR = "HOUR" """ """ - MINUTE = 'MINUTE' + MINUTE = "MINUTE" """ """ - SECOND = 'SECOND' + SECOND = "SECOND" """ """ - MILLISECOND = 'MILLISECOND' + MILLISECOND = "MILLISECOND" """ """ - PRECISIONS = (YEAR, MONTH, DAY, HOUR, - MINUTE, SECOND, MILLISECOND) + PRECISIONS = (YEAR, MONTH, DAY, HOUR, MINUTE, SECOND, MILLISECOND) """ """ @@ -1394,20 +1285,20 @@ def _round_to_precision(cls, ms, precision, default_dt): precision_idx = cls._to_int(precision) replace_kwargs = {} if precision_idx <= cls._to_int(DateRangePrecision.YEAR): - replace_kwargs['month'] = default_dt.month + replace_kwargs["month"] = default_dt.month if precision_idx <= cls._to_int(DateRangePrecision.MONTH): - replace_kwargs['day'] = default_dt.day + replace_kwargs["day"] = default_dt.day if precision_idx <= cls._to_int(DateRangePrecision.DAY): - replace_kwargs['hour'] = default_dt.hour + replace_kwargs["hour"] = default_dt.hour if precision_idx <= cls._to_int(DateRangePrecision.HOUR): - replace_kwargs['minute'] = default_dt.minute + replace_kwargs["minute"] = default_dt.minute if precision_idx <= cls._to_int(DateRangePrecision.MINUTE): - replace_kwargs['second'] = default_dt.second + replace_kwargs["second"] = default_dt.second if precision_idx <= cls._to_int(DateRangePrecision.SECOND): # truncate to nearest 1000 so we deal in ms, not us - replace_kwargs['microsecond'] = (default_dt.microsecond // 1000) * 1000 + replace_kwargs["microsecond"] = (default_dt.microsecond // 1000) * 1000 if precision_idx == cls._to_int(DateRangePrecision.MILLISECOND): - replace_kwargs['microsecond'] = int(round(dt.microsecond, -3)) + replace_kwargs["microsecond"] = int(round(dt.microsecond, -3)) return ms_timestamp_from_datetime(dt.replace(**replace_kwargs)) @classmethod @@ -1417,8 +1308,11 @@ def round_up_to_precision(cls, ms, precision): # be setting 31 as the month day if precision == cls.MONTH: date_ms = utc_datetime_from_ms_timestamp(ms) - upper_date = datetime.datetime.max.replace(year=date_ms.year, month=date_ms.month, - day=calendar.monthrange(date_ms.year, date_ms.month)[1]) + upper_date = datetime.datetime.max.replace( + year=date_ms.year, + month=date_ms.month, + day=calendar.monthrange(date_ms.year, date_ms.month)[1], + ) else: upper_date = datetime.datetime.max return cls._round_to_precision(ms, precision, upper_date) @@ -1447,6 +1341,7 @@ class DateRangeBound(object): For such values, string representions will show this offset rather than the CQL representation. """ + milliseconds = None precision = None @@ -1460,19 +1355,17 @@ def __init__(self, value, precision): try: self.precision = precision.upper() except AttributeError: - raise TypeError('precision must be a string; got %r' % precision) + raise TypeError("precision must be a string; got %r" % precision) if value is None: milliseconds = None elif isinstance(value, int): milliseconds = value elif isinstance(value, datetime.datetime): - value = value.replace( - microsecond=int(round(value.microsecond, -3)) - ) + value = value.replace(microsecond=int(round(value.microsecond, -3))) milliseconds = ms_timestamp_from_datetime(value) else: - raise ValueError('%r is not a valid value for DateRangeBound' % value) + raise ValueError("%r is not a valid value for DateRangeBound" % value) self.milliseconds = milliseconds self.validate() @@ -1480,12 +1373,16 @@ def __init__(self, value, precision): def __eq__(self, other): if not isinstance(other, self.__class__): return NotImplemented - return (self.milliseconds == other.milliseconds and - self.precision == other.precision) + return ( + self.milliseconds == other.milliseconds + and self.precision == other.precision + ) def __lt__(self, other): - return ((str(self.milliseconds), str(self.precision)) < - (str(other.milliseconds), str(other.precision))) + return (str(self.milliseconds), str(self.precision)) < ( + str(other.milliseconds), + str(other.precision), + ) def datetime(self): """ @@ -1500,17 +1397,19 @@ def validate(self): return if None in attrs: raise TypeError( - ("%s.datetime and %s.precision must not be None unless both " - "are None; Got: %r") % (self.__class__.__name__, - self.__class__.__name__, - self) + ( + "%s.datetime and %s.precision must not be None unless both " + "are None; Got: %r" + ) + % (self.__class__.__name__, self.__class__.__name__, self) ) if self.precision not in DateRangePrecision.PRECISIONS: raise ValueError( - "%s.precision: expected value in %r; got %r" % ( + "%s.precision: expected value in %r; got %r" + % ( self.__class__.__name__, DateRangePrecision.PRECISIONS, - self.precision + self.precision, ) ) @@ -1529,7 +1428,7 @@ def from_value(cls, value): # if possible, use as a mapping try: - milliseconds, precision = value.get('milliseconds'), value.get('precision') + milliseconds, precision = value.get("milliseconds"), value.get("precision") except AttributeError: milliseconds = precision = None if milliseconds is not None and precision is not None: @@ -1555,23 +1454,23 @@ def round_down(self): return self _formatter_map = { - DateRangePrecision.YEAR: '%Y', - DateRangePrecision.MONTH: '%Y-%m', - DateRangePrecision.DAY: '%Y-%m-%d', - DateRangePrecision.HOUR: '%Y-%m-%dT%HZ', - DateRangePrecision.MINUTE: '%Y-%m-%dT%H:%MZ', - DateRangePrecision.SECOND: '%Y-%m-%dT%H:%M:%SZ', - DateRangePrecision.MILLISECOND: '%Y-%m-%dT%H:%M:%S', + DateRangePrecision.YEAR: "%Y", + DateRangePrecision.MONTH: "%Y-%m", + DateRangePrecision.DAY: "%Y-%m-%d", + DateRangePrecision.HOUR: "%Y-%m-%dT%HZ", + DateRangePrecision.MINUTE: "%Y-%m-%dT%H:%MZ", + DateRangePrecision.SECOND: "%Y-%m-%dT%H:%M:%SZ", + DateRangePrecision.MILLISECOND: "%Y-%m-%dT%H:%M:%S", } def __str__(self): if self == OPEN_BOUND: - return '*' + return "*" try: dt = self.datetime() except OverflowError: - return '%sms' % (self.milliseconds,) + return "%sms" % (self.milliseconds,) formatted = dt.strftime(self._formatter_map[self.precision]) @@ -1579,13 +1478,15 @@ def __str__(self): # we'd like to just format with '%Y-%m-%dT%H:%M:%S.%fZ', but %f # gives us more precision than we want, so we strftime up to %S and # do the rest ourselves - return '%s.%03dZ' % (formatted, dt.microsecond / 1000) + return "%s.%03dZ" % (formatted, dt.microsecond / 1000) return formatted def __repr__(self): - return '%s(milliseconds=%r, precision=%r)' % ( - self.__class__.__name__, self.milliseconds, self.precision + return "%s(milliseconds=%r, precision=%r)" % ( + self.__class__.__name__, + self.milliseconds, + self.precision, ) @@ -1617,6 +1518,7 @@ class DateRange(object): `datetime.datetime` cannot. For such values, string representions will show this offset rather than the CQL representation. """ + lower_bound = None upper_bound = None value = None @@ -1639,12 +1541,17 @@ def __init__(self, lower_bound=None, upper_bound=None, value=None): """ # if necessary, transform non-None args to DateRangeBounds - lower_bound = (DateRangeBound.from_value(lower_bound).round_down() - if lower_bound else lower_bound) - upper_bound = (DateRangeBound.from_value(upper_bound).round_up() - if upper_bound else upper_bound) - value = (DateRangeBound.from_value(value).round_down() - if value else value) + lower_bound = ( + DateRangeBound.from_value(lower_bound).round_down() + if lower_bound + else lower_bound + ) + upper_bound = ( + DateRangeBound.from_value(upper_bound).round_up() + if upper_bound + else upper_bound + ) + value = DateRangeBound.from_value(value).round_down() if value else value # if we're using a 2-ended range but one bound isn't specified, specify # an open bound @@ -1654,7 +1561,9 @@ def __init__(self, lower_bound=None, upper_bound=None, value=None): upper_bound = OPEN_BOUND self.lower_bound, self.upper_bound, self.value = ( - lower_bound, upper_bound, value + lower_bound, + upper_bound, + value, ) self.validate() @@ -1662,43 +1571,46 @@ def validate(self): if self.value is None: if self.lower_bound is None or self.upper_bound is None: raise ValueError( - '%s instances where value attribute is None must set ' - 'lower_bound or upper_bound; got %r' % ( - self.__class__.__name__, - self - ) + "%s instances where value attribute is None must set " + "lower_bound or upper_bound; got %r" + % (self.__class__.__name__, self) ) else: # self.value is not None if self.lower_bound is not None or self.upper_bound is not None: raise ValueError( - '%s instances where value attribute is not None must not ' - 'set lower_bound or upper_bound; got %r' % ( - self.__class__.__name__, - self - ) + "%s instances where value attribute is not None must not " + "set lower_bound or upper_bound; got %r" + % (self.__class__.__name__, self) ) def __eq__(self, other): if not isinstance(other, self.__class__): return NotImplemented - return (self.lower_bound == other.lower_bound and - self.upper_bound == other.upper_bound and - self.value == other.value) + return ( + self.lower_bound == other.lower_bound + and self.upper_bound == other.upper_bound + and self.value == other.value + ) def __lt__(self, other): - return ((str(self.lower_bound), str(self.upper_bound), str(self.value)) < - (str(other.lower_bound), str(other.upper_bound), str(other.value))) + return (str(self.lower_bound), str(self.upper_bound), str(self.value)) < ( + str(other.lower_bound), + str(other.upper_bound), + str(other.value), + ) def __str__(self): if self.value: return str(self.value) else: - return '[%s TO %s]' % (self.lower_bound, self.upper_bound) + return "[%s TO %s]" % (self.lower_bound, self.upper_bound) def __repr__(self): - return '%s(lower_bound=%r, upper_bound=%r, value=%r)' % ( + return "%s(lower_bound=%r, upper_bound=%r, value=%r)" % ( self.__class__.__name__, - self.lower_bound, self.upper_bound, self.value + self.lower_bound, + self.upper_bound, + self.value, ) @@ -1720,22 +1632,28 @@ class Version(object): def __init__(self, version): self._version = version - if '-' in version: - version_without_prerelease, self.prerelease = version.split('-', 1) + if "-" in version: + version_without_prerelease, self.prerelease = version.split("-", 1) else: version_without_prerelease = version - parts = list(reversed(version_without_prerelease.split('.'))) + parts = list(reversed(version_without_prerelease.split("."))) if len(parts) > 4: prerelease_string = "-{}".format(self.prerelease) if self.prerelease else "" - log.warning("Unrecognized version: {}. Only 4 components plus prerelease are supported. " - "Assuming version as {}{}".format(version, '.'.join(parts[:-5:-1]), prerelease_string)) + log.warning( + "Unrecognized version: {}. Only 4 components plus prerelease are supported. " + "Assuming version as {}{}".format( + version, ".".join(parts[:-5:-1]), prerelease_string + ) + ) try: self.major = int(parts.pop()) except ValueError as e: raise ValueError( - "Couldn't parse version {}. Version should start with a number".format(version))\ - .with_traceback(e.__traceback__) + "Couldn't parse version {}. Version should start with a number".format( + version + ) + ).with_traceback(e.__traceback__) try: self.minor = int(parts.pop()) if parts else 0 self.patch = int(parts.pop()) if parts else 0 @@ -1747,14 +1665,22 @@ def __init__(self, version): except ValueError: self.build = build except ValueError: - assumed_version = "{}.{}.{}.{}-{}".format(self.major, self.minor, self.patch, self.build, self.prerelease) - log.warning("Unrecognized version {}. Assuming version as {}".format(version, assumed_version)) + assumed_version = "{}.{}.{}.{}-{}".format( + self.major, self.minor, self.patch, self.build, self.prerelease + ) + log.warning( + "Unrecognized version {}. Assuming version as {}".format( + version, assumed_version + ) + ) def __hash__(self): return self._version def __repr__(self): - version_string = "Version({0}, {1}, {2}".format(self.major, self.minor, self.patch) + version_string = "Version({0}, {1}, {2}".format( + self.major, self.minor, self.patch + ) if self.build: version_string += ", {}".format(self.build) if self.prerelease: @@ -1768,8 +1694,7 @@ def __str__(self): @staticmethod def _compare_version_part(version, other_version, cmp): - if not (isinstance(version, int) and - isinstance(other_version, int)): + if not (isinstance(version, int) and isinstance(other_version, int)): version = str(version) other_version = str(other_version) @@ -1779,12 +1704,15 @@ def __eq__(self, other): if not isinstance(other, Version): return NotImplemented - return (self.major == other.major and - self.minor == other.minor and - self.patch == other.patch and - self._compare_version_part(self.build, other.build, lambda s, o: s == o) and - self._compare_version_part(self.prerelease, other.prerelease, lambda s, o: s == o) - ) + return ( + self.major == other.major + and self.minor == other.minor + and self.patch == other.patch + and self._compare_version_part(self.build, other.build, lambda s, o: s == o) + and self._compare_version_part( + self.prerelease, other.prerelease, lambda s, o: s == o + ) + ) def __gt__(self, other): if not isinstance(other, Version): @@ -1793,8 +1721,12 @@ def __gt__(self, other): is_major_ge = self.major >= other.major is_minor_ge = self.minor >= other.minor is_patch_ge = self.patch >= other.patch - is_build_gt = self._compare_version_part(self.build, other.build, lambda s, o: s > o) - is_build_ge = self._compare_version_part(self.build, other.build, lambda s, o: s >= o) + is_build_gt = self._compare_version_part( + self.build, other.build, lambda s, o: s > o + ) + is_build_ge = self._compare_version_part( + self.build, other.build, lambda s, o: s >= o + ) # By definition, a prerelease comes BEFORE the actual release, so if a version # doesn't have a prerelease, it's automatically greater than anything that does @@ -1803,17 +1735,28 @@ def __gt__(self, other): elif other.prerelease and not self.prerelease: is_prerelease_gt = True else: - is_prerelease_gt = self._compare_version_part(self.prerelease, other.prerelease, lambda s, o: s > o) \ + is_prerelease_gt = self._compare_version_part( + self.prerelease, other.prerelease, lambda s, o: s > o + ) - return (self.major > other.major or - (is_major_ge and self.minor > other.minor) or - (is_major_ge and is_minor_ge and self.patch > other.patch) or - (is_major_ge and is_minor_ge and is_patch_ge and is_build_gt) or - (is_major_ge and is_minor_ge and is_patch_ge and is_build_ge and is_prerelease_gt) - ) + return ( + self.major > other.major + or (is_major_ge and self.minor > other.minor) + or (is_major_ge and is_minor_ge and self.patch > other.patch) + or (is_major_ge and is_minor_ge and is_patch_ge and is_build_gt) + or ( + is_major_ge + and is_minor_ge + and is_patch_ge + and is_build_ge + and is_prerelease_gt + ) + ) -def maybe_add_timeout_to_query(stmt: str, metadata_request_timeout: Optional[datetime.timedelta]) -> str: +def maybe_add_timeout_to_query( + stmt: str, metadata_request_timeout: Optional[datetime.timedelta] +) -> str: if metadata_request_timeout is None: return stmt ms = int(metadata_request_timeout / datetime.timedelta(milliseconds=1)) From 22d6027980704920bc5805f727138d8e9ee7d85d Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Wed, 4 Mar 2026 19:11:17 +0200 Subject: [PATCH 11/18] remove: dead try/except around subprocess import in setup.py The subprocess module has been part of Python's standard library since Python 2.4. The try/except ImportError guard and has_subprocess flag were unreachable dead code that added unnecessary indentation and complexity to the doc-building logic. Replace with a direct 'import subprocess' and remove the conditional guard around the documentation build steps. --- setup.py | 314 +++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 201 insertions(+), 113 deletions(-) diff --git a/setup.py b/setup.py index 52e04a63e5..2549f5d0fa 100644 --- a/setup.py +++ b/setup.py @@ -19,20 +19,16 @@ from pathlib import Path from setuptools.command.build_ext import build_ext from setuptools import Extension, Command, setup -from setuptools.errors import (CCompilerError, PlatformError, - ExecError) +from setuptools.errors import CCompilerError, PlatformError, ExecError -try: - import subprocess - has_subprocess = True -except ImportError: - has_subprocess = False +import subprocess has_cqlengine = False -if __name__ == '__main__' and len(sys.argv) > 1 and sys.argv[1] == "install": +if __name__ == "__main__" and len(sys.argv) > 1 and sys.argv[1] == "install": try: import cqlengine + has_cqlengine = True except ImportError: pass @@ -41,11 +37,9 @@ class DocCommand(Command): - description = "generate or test documentation" - user_options = [("test", "t", - "run doctests instead of generating documentation")] + user_options = [("test", "t", "run doctests instead of generating documentation")] boolean_options = ["test"] @@ -61,6 +55,7 @@ def run(self): mode = "doctest" else: from cassandra import __version__ + path = "docs/_build/%s" % __version__ mode = "html" @@ -69,57 +64,83 @@ def run(self): except: pass - if has_subprocess: - # Prevent run with in-place extensions because cython-generated objects do not carry docstrings - # http://docs.cython.org/src/userguide/special_methods.html#docstrings - import glob - for f in glob.glob("cassandra/*.so"): - print("Removing '%s' to allow docs to run on pure python modules." %(f,)) - os.unlink(f) + # Prevent run with in-place extensions because cython-generated objects do not carry docstrings + # http://docs.cython.org/src/userguide/special_methods.html#docstrings + import glob - # Build io extension to make import and docstrings work - try: - output = subprocess.check_output( - ["python", "setup.py", "build_ext", "--inplace", "--force", "--no-murmur3", "--no-cython"], - stderr=subprocess.STDOUT) - except subprocess.CalledProcessError as exc: - raise RuntimeError("Documentation step '%s' failed: %s: %s" % ("build_ext", exc, exc.output)) - else: - print(output) + for f in glob.glob("cassandra/*.so"): + print("Removing '%s' to allow docs to run on pure python modules." % (f,)) + os.unlink(f) - try: - output = subprocess.check_output( - ["sphinx-build", "-b", mode, "docs", path], - stderr=subprocess.STDOUT) - except subprocess.CalledProcessError as exc: - raise RuntimeError("Documentation step '%s' failed: %s: %s" % (mode, exc, exc.output)) - else: - print(output) + # Build io extension to make import and docstrings work + try: + output = subprocess.check_output( + [ + sys.executable, + "setup.py", + "build_ext", + "--inplace", + "--force", + "--no-murmur3", + "--no-cython", + ], + stderr=subprocess.STDOUT, + ) + except subprocess.CalledProcessError as exc: + raise RuntimeError( + "Documentation step '%s' failed: %s: %s" + % ("build_ext", exc, exc.output) + ) + else: + print(output) - print("") - print("Documentation step '%s' performed, results here:" % mode) - print(" file://%s/%s/index.html" % (os.path.dirname(os.path.realpath(__file__)), path)) + try: + output = subprocess.check_output( + [sys.executable, "-m", "sphinx", "-b", mode, "docs", path], + stderr=subprocess.STDOUT, + ) + except subprocess.CalledProcessError as exc: + raise RuntimeError( + "Documentation step '%s' failed: %s: %s" % (mode, exc, exc.output) + ) + else: + print(output) + print("") + print("Documentation step '%s' performed, results here:" % mode) + print( + " file://%s/%s/index.html" + % (os.path.dirname(os.path.realpath(__file__)), path) + ) -class BuildFailed(Exception): +class BuildFailed(Exception): def __init__(self, ext): self.ext = ext -is_windows = sys.platform.startswith('win32') -is_macos = sys.platform.startswith('darwin') + +is_windows = sys.platform.startswith("win32") +is_macos = sys.platform.startswith("darwin") + def get_subdriname(directory_path): try: # List only subdirectories in the given directory - subdirectories = [name for dir in directory_path for name in os.listdir(dir) - if os.path.isdir(os.path.join(directory_path, name))] + subdirectories = [ + name + for name in os.listdir(directory_path) + if os.path.isdir(os.path.join(directory_path, name)) + ] return subdirectories except Exception: return [] + def get_libev_headers_path(): - libev_hb_paths = ["/opt/homebrew/Cellar/libev", os.path.expanduser('~/homebrew/Cellar/libev')] + libev_hb_paths = [ + "/opt/homebrew/Cellar/libev", + os.path.expanduser("~/homebrew/Cellar/libev"), + ] for hb_path in libev_hb_paths: if not os.path.exists(hb_path): continue @@ -127,59 +148,70 @@ def get_libev_headers_path(): if not versions: continue picked_version = sorted(versions, reverse=True)[0] - resulted_path = os.path.join(hb_path, picked_version, 'include') + resulted_path = os.path.join(hb_path, picked_version, "include") warnings.warn("found libev headers in '%s'" % resulted_path) return [resulted_path] warnings.warn("did not find libev headers in '%s'" % libev_hb_paths) return [] -murmur3_ext = Extension('cassandra.cmurmur3', - sources=['cassandra/cmurmur3.c']) +murmur3_ext = Extension("cassandra.cmurmur3", sources=["cassandra/cmurmur3.c"]) -is_macos = sys.platform.startswith('darwin') def eval_env_var_as_array(varname): val = os.environ.get(varname) - return None if not val else [v.strip() for v in val.split(',')] + return None if not val else [v.strip() for v in val.split(",")] -DEFAULT_LIBEV_INCLUDES = ['/usr/include/libev', '/usr/local/include', '/opt/local/include', '/usr/include'] -DEFAULT_LIBEV_LIBDIRS = ['/usr/local/lib', '/opt/local/lib', '/usr/lib64'] -libev_includes = eval_env_var_as_array('CASS_DRIVER_LIBEV_INCLUDES') or DEFAULT_LIBEV_INCLUDES -libev_libdirs = eval_env_var_as_array('CASS_DRIVER_LIBEV_LIBS') or DEFAULT_LIBEV_LIBDIRS -if is_macos: - libev_includes.extend(['/opt/homebrew/include', os.path.expanduser('~/homebrew/include'), *get_libev_headers_path()]) - libev_libdirs.extend(['/opt/homebrew/lib']) +DEFAULT_LIBEV_INCLUDES = [ + "/usr/include/libev", + "/usr/local/include", + "/opt/local/include", + "/usr/include", +] +DEFAULT_LIBEV_LIBDIRS = ["/usr/local/lib", "/opt/local/lib", "/usr/lib64"] +libev_includes = ( + eval_env_var_as_array("CASS_DRIVER_LIBEV_INCLUDES") or DEFAULT_LIBEV_INCLUDES +) +libev_libdirs = eval_env_var_as_array("CASS_DRIVER_LIBEV_LIBS") or DEFAULT_LIBEV_LIBDIRS -conan_envfile = Path(__file__).parent / 'build-release/conan/conandeps.env' +if is_macos: + libev_includes.extend( + [ + "/opt/homebrew/include", + os.path.expanduser("~/homebrew/include"), + *get_libev_headers_path(), + ] + ) + libev_libdirs.extend(["/opt/homebrew/lib"]) + +conan_envfile = Path(__file__).parent / "build-release/conan/conandeps.env" if conan_envfile.exists(): conan_paths = json.loads(conan_envfile.read_text()) - libev_includes.extend([conan_paths.get('include_dirs')]) - libev_libdirs.extend([conan_paths.get('library_dirs')]) - -libev_ext = Extension('cassandra.io.libevwrapper', - sources=['cassandra/io/libevwrapper.c'], - include_dirs=libev_includes, - libraries=['ev'], - library_dirs=libev_libdirs) - -platform_unsupported_msg = \ -""" + libev_includes.extend([conan_paths.get("include_dirs")]) + libev_libdirs.extend([conan_paths.get("library_dirs")]) + +libev_ext = Extension( + "cassandra.io.libevwrapper", + sources=["cassandra/io/libevwrapper.c"], + include_dirs=libev_includes, + libraries=["ev"], + library_dirs=libev_libdirs, +) + +platform_unsupported_msg = """ =============================================================================== The optional C extensions are not supported on this platform. =============================================================================== """ -arch_unsupported_msg = \ -""" +arch_unsupported_msg = """ =============================================================================== The optional C extensions are not supported on big-endian systems. =============================================================================== """ -pypy_unsupported_msg = \ -""" +pypy_unsupported_msg = """ ================================================================================= Some optional C extensions are not supported in PyPy. Only murmur3 will be built. ================================================================================= @@ -196,20 +228,41 @@ def eval_env_var_as_array(varname): elif not is_supported_arch: sys.stderr.write(arch_unsupported_msg) -try_extensions = "--no-extensions" not in sys.argv and is_supported_platform and is_supported_arch and not os.environ.get('CASS_DRIVER_NO_EXTENSIONS') +try_extensions = ( + "--no-extensions" not in sys.argv + and is_supported_platform + and is_supported_arch + and not os.environ.get("CASS_DRIVER_NO_EXTENSIONS") +) try_murmur3 = try_extensions and "--no-murmur3" not in sys.argv -try_libev = try_extensions and "--no-libev" not in sys.argv and not is_pypy and not os.environ.get('CASS_DRIVER_NO_LIBEV') -try_cython = try_extensions and "--no-cython" not in sys.argv and not is_pypy and not os.environ.get('CASS_DRIVER_NO_CYTHON') -sys.argv = [a for a in sys.argv if a not in ("--no-murmur3", "--no-libev", "--no-cython", "--no-extensions")] - -build_concurrency = int(os.environ.get('CASS_DRIVER_BUILD_CONCURRENCY', '0')) +try_libev = ( + try_extensions + and "--no-libev" not in sys.argv + and not is_pypy + and not os.environ.get("CASS_DRIVER_NO_LIBEV") +) +try_cython = ( + try_extensions + and "--no-cython" not in sys.argv + and not is_pypy + and not os.environ.get("CASS_DRIVER_NO_CYTHON") +) +sys.argv = [ + a + for a in sys.argv + if a not in ("--no-murmur3", "--no-libev", "--no-cython", "--no-extensions") +] + +build_concurrency = int(os.environ.get("CASS_DRIVER_BUILD_CONCURRENCY", "0")) if build_concurrency == 0: build_concurrency = None -CASS_DRIVER_BUILD_EXTENSIONS_ARE_MUST = bool(os.environ.get('CASS_DRIVER_BUILD_EXTENSIONS_ARE_MUST', 'no') == 'yes') +CASS_DRIVER_BUILD_EXTENSIONS_ARE_MUST = bool( + os.environ.get("CASS_DRIVER_BUILD_EXTENSIONS_ARE_MUST", "no") == "yes" +) -class NoPatchExtension(Extension): +class NoPatchExtension(Extension): # Older versions of setuptools.extension has a static flag which is set False before our # setup_requires lands Cython. It causes our *.pyx sources to be renamed to *.c in # the initializer. @@ -221,7 +274,7 @@ class NoPatchExtension(Extension): def __init__(self, *args, **kwargs): # bypass the patched init if possible if Extension.__bases__: - base, = Extension.__bases__ + (base,) = Extension.__bases__ base.__init__(self, *args, **kwargs) else: Extension.__init__(self, *args, **kwargs) @@ -230,7 +283,8 @@ def __init__(self, *args, **kwargs): class build_extensions(build_ext): _needs_stub = False - error_message = """ + error_message = ( + """ =============================================================================== WARNING: could not compile %s. @@ -242,7 +296,9 @@ class build_extensions(build_ext): This is often a matter of using vcvarsall.bat from your install directory, or running from a command prompt in the Visual Studio Tools Start Menu. =============================================================================== -""" if is_windows else """ +""" + if is_windows + else """ =============================================================================== WARNING: could not compile %s. @@ -279,33 +335,35 @@ class build_extensions(build_ext): =============================================================================== """ + ) def run(self): try: self._setup_extensions() build_ext.run(self) except PlatformError as exc: - sys.stderr.write('%s\n' % str(exc)) + sys.stderr.write("%s\n" % str(exc)) warnings.warn(self.error_message % "C extensions.") if CASS_DRIVER_BUILD_EXTENSIONS_ARE_MUST: raise - def build_extensions(self): if build_concurrency is None or build_concurrency > 1: self.check_extensions_list(self.extensions) import multiprocessing.pool - multiprocessing.pool.ThreadPool(processes=build_concurrency).map(self.build_extension, self.extensions) + + multiprocessing.pool.ThreadPool(processes=build_concurrency).map( + self.build_extension, self.extensions + ) else: build_ext.build_extensions(self) def build_extension(self, ext): try: build_ext.build_extension(self, fix_extension_class(ext)) - except (CCompilerError, ExecError, - PlatformError, IOError) as exc: - sys.stderr.write('%s\n' % str(exc)) + except (CCompilerError, ExecError, PlatformError, IOError) as exc: + sys.stderr.write("%s\n" % str(exc)) name = "The %s extension" % (ext.name,) warnings.warn(self.error_message % (name,)) if CASS_DRIVER_BUILD_EXTENSIONS_ARE_MUST: @@ -326,25 +384,49 @@ def _setup_extensions(self): if try_cython: try: from Cython.Build import cythonize - cython_candidates = ['cluster', 'concurrent', 'connection', 'cqltypes', 'metadata', - 'pool', 'protocol', 'query', 'util', 'shard_info'] - compile_args = [] if is_windows else ['-Wno-unused-function'] - self.extensions.extend(cythonize( - [Extension('cassandra.%s' % m, ['cassandra/%s.py' % m], - extra_compile_args=compile_args) - for m in cython_candidates], - nthreads=build_concurrency, - compiler_directives={'language_level': 3}, - exclude_failures=not CASS_DRIVER_BUILD_EXTENSIONS_ARE_MUST, - )) - - self.extensions.extend(cythonize( - NoPatchExtension("*", ["cassandra/*.pyx"], extra_compile_args=compile_args), - nthreads=build_concurrency, - compiler_directives={'language_level': 3}, - )) + + cython_candidates = [ + "cluster", + "concurrent", + "connection", + "cqltypes", + "metadata", + "pool", + "protocol", + "query", + "util", + "shard_info", + ] + compile_args = [] if is_windows else ["-Wno-unused-function"] + self.extensions.extend( + cythonize( + [ + Extension( + "cassandra.%s" % m, + ["cassandra/%s.py" % m], + extra_compile_args=compile_args, + ) + for m in cython_candidates + ], + nthreads=build_concurrency, + compiler_directives={"language_level": 3}, + exclude_failures=not CASS_DRIVER_BUILD_EXTENSIONS_ARE_MUST, + ) + ) + + self.extensions.extend( + cythonize( + NoPatchExtension( + "*", ["cassandra/*.pyx"], extra_compile_args=compile_args + ), + nthreads=build_concurrency, + compiler_directives={"language_level": 3}, + ) + ) except Exception: - sys.stderr.write("Failed to cythonize one or more modules. These will not be compiled as extensions (optional).\n") + sys.stderr.write( + "Failed to cythonize one or more modules. These will not be compiled as extensions (optional).\n" + ) if CASS_DRIVER_BUILD_EXTENSIONS_ARE_MUST: raise @@ -357,15 +439,21 @@ def fix_extension_class(ext: Extension) -> Extension: def run_setup(extensions): - kw = {'cmdclass': {'doc': DocCommand}} - kw['cmdclass']['build_ext'] = build_extensions - kw['ext_modules'] = [Extension('DUMMY', [])] # dummy extension makes sure build_ext is called for install + kw = {"cmdclass": {"doc": DocCommand}} + kw["cmdclass"]["build_ext"] = build_extensions + kw["ext_modules"] = [ + Extension("DUMMY", []) + ] # dummy extension makes sure build_ext is called for install setup(**kw) + run_setup(None) if has_cqlengine: - warnings.warn("\n#######\n'cqlengine' package is present on path: %s\n" - "cqlengine is now an integrated sub-package of this driver.\n" - "It is recommended to remove this package to reduce the chance for conflicting usage" % cqlengine.__file__) + warnings.warn( + "\n#######\n'cqlengine' package is present on path: %s\n" + "cqlengine is now an integrated sub-package of this driver.\n" + "It is recommended to remove this package to reduce the chance for conflicting usage" + % cqlengine.__file__ + ) From 6f0026d0b1dc7680439e4c9610491a2f49337354 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Wed, 4 Mar 2026 19:11:58 +0200 Subject: [PATCH 12/18] refactor: rename __nonzero__ to __bool__ in ResultSet Python 3 uses __bool__ for truth-value testing; __nonzero__ was the Python 2 equivalent. The code previously defined __nonzero__ and aliased __bool__ = __nonzero__ for cross-compatibility. Since Python 3 never calls __nonzero__, rename the method directly to __bool__ and remove the alias. --- cassandra/cluster.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index a7f2b98a10..af8960504f 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -6346,11 +6346,9 @@ def __getitem__(self, i): self._enter_list_mode("index operator") return self._current_rows[i] - def __nonzero__(self): + def __bool__(self): return bool(self._current_rows) - __bool__ = __nonzero__ - def get_query_trace(self, max_wait_sec=None): """ Gets the last query trace from the associated future. From b49d89420a8ee3a54adfda5e3a6444eaad676836 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Wed, 4 Mar 2026 19:12:47 +0200 Subject: [PATCH 13/18] remove: delete test_export_keyspace_schema_udts that requires Python 2.7 This integration test was permanently dead code: it contained a guard 'if sys.version_info[0:2] != (2, 7): raise SkipTest(...)' which means it was always skipped on Python 3. The skip reason stated that the test compares static strings from dict items whose ordering is not deterministic on Python 3. Since the driver no longer supports Python 2, and fixing the test to use order-independent comparison is a separate concern, remove the permanently-skipped test entirely. --- tests/integration/standard/test_metadata.py | 1557 ++++++++++++------- 1 file changed, 1025 insertions(+), 532 deletions(-) diff --git a/tests/integration/standard/test_metadata.py b/tests/integration/standard/test_metadata.py index c30e369d83..3a26db9799 100644 --- a/tests/integration/standard/test_metadata.py +++ b/tests/integration/standard/test_metadata.py @@ -26,28 +26,67 @@ from unittest.mock import Mock, patch import pytest -from cassandra import AlreadyExists, SignatureDescriptor, UserFunctionDescriptor, UserAggregateDescriptor +from cassandra import ( + AlreadyExists, + SignatureDescriptor, + UserFunctionDescriptor, + UserAggregateDescriptor, +) from cassandra.connection import Connection from cassandra.encoder import Encoder -from cassandra.metadata import (IndexMetadata, Token, murmur3, Function, Aggregate, protect_name, protect_names, - RegisteredTableExtension, _RegisteredExtensionType, get_schema_parser, - group_keys_by_replica, NO_VALID_REPLICA) +from cassandra.metadata import ( + IndexMetadata, + Token, + murmur3, + Function, + Aggregate, + protect_name, + protect_names, + RegisteredTableExtension, + _RegisteredExtensionType, + get_schema_parser, + group_keys_by_replica, + NO_VALID_REPLICA, +) from cassandra.protocol import QueryMessage, ProtocolHandler from cassandra.util import SortedSet -from tests.integration import (get_cluster, use_singledc, PROTOCOL_VERSION, execute_until_pass, - BasicSegregatedKeyspaceUnitTestCase, BasicSharedKeyspaceUnitTestCase, - BasicExistingKeyspaceUnitTestCase, drop_keyspace_shutdown_cluster, CASSANDRA_VERSION, - greaterthanorequalcass30, lessthancass30, local, - get_supported_protocol_versions, greaterthancass20, - greaterthancass21, greaterthanorequalcass40, - lessthancass40, - TestCluster, requires_java_udf, requires_composite_type, - requires_collection_indexes, SCYLLA_VERSION, xfail_scylla, xfail_scylla_version_lt, - requirescompactstorage) - -from tests.util import wait_until, assertRegex, assertDictEqual, assertListEqual, assert_startswith_diff +from tests.integration import ( + get_cluster, + use_singledc, + PROTOCOL_VERSION, + execute_until_pass, + BasicSegregatedKeyspaceUnitTestCase, + BasicSharedKeyspaceUnitTestCase, + BasicExistingKeyspaceUnitTestCase, + drop_keyspace_shutdown_cluster, + CASSANDRA_VERSION, + greaterthanorequalcass30, + lessthancass30, + local, + get_supported_protocol_versions, + greaterthancass20, + greaterthancass21, + greaterthanorequalcass40, + lessthancass40, + TestCluster, + requires_java_udf, + requires_composite_type, + requires_collection_indexes, + SCYLLA_VERSION, + xfail_scylla, + xfail_scylla_version_lt, + requirescompactstorage, +) + +from tests.util import ( + wait_until, + assertRegex, + assertDictEqual, + assertListEqual, + assert_startswith_diff, +) log = logging.getLogger(__name__) @@ -75,7 +114,7 @@ def test_host_addresses(self): assert host.broadcast_rpc_address is not None assert host.host_id is not None - if CASSANDRA_VERSION >= Version('4-a'): + if CASSANDRA_VERSION >= Version("4-a"): assert host.broadcast_port is not None assert host.broadcast_rpc_port is not None @@ -85,16 +124,21 @@ def test_host_addresses(self): # The control connection node should have the listen address set. # Note: Scylla does not populate listen_address in system.local if SCYLLA_VERSION is None: - listen_addrs = [host.listen_address for host in self.cluster.metadata.all_hosts()] + listen_addrs = [ + host.listen_address for host in self.cluster.metadata.all_hosts() + ] assert local_host in listen_addrs # The control connection node should have the broadcast_rpc_address set. - rpc_addrs = [host.broadcast_rpc_address for host in self.cluster.metadata.all_hosts()] + rpc_addrs = [ + host.broadcast_rpc_address for host in self.cluster.metadata.all_hosts() + ] assert local_host in rpc_addrs @unittest.skipUnless( - os.getenv('MAPPED_CASSANDRA_VERSION', None) is not None, - "Don't check the host version for test-dse") + os.getenv("MAPPED_CASSANDRA_VERSION", None) is not None, + "Don't check the host version for test-dse", + ) def test_host_release_version(self): """ Checks the hosts release version and validates that it is equal to the @@ -110,12 +154,12 @@ def test_host_release_version(self): assert host.release_version.startswith(CASSANDRA_VERSION.base_version) - @local class MetaDataRemovalTest(unittest.TestCase): - def setUp(self): - self.cluster = TestCluster(contact_points=['127.0.0.1', '127.0.0.2', '127.0.0.3', '126.0.0.186']) + self.cluster = TestCluster( + contact_points=["127.0.0.1", "127.0.0.2", "127.0.0.3", "126.0.0.186"] + ) self.cluster.connect() def tearDown(self): @@ -132,15 +176,18 @@ def test_bad_contact_point(self): @test_category metadata """ # wait until we have only 3 hosts - wait_until(condition=lambda: len(self.cluster.metadata.all_hosts()) == 3, delay=0.5, max_attempts=5) + wait_until( + condition=lambda: len(self.cluster.metadata.all_hosts()) == 3, + delay=0.5, + max_attempts=5, + ) # verify the un-existing host was filtered for host in self.cluster.metadata.all_hosts(): - assert host.endpoint.address != '126.0.0.186' + assert host.endpoint.address != "126.0.0.186" class SchemaMetadataTests(BasicSegregatedKeyspaceUnitTestCase): - def test_schema_metadata_disable(self): """ Checks to ensure that schema metadata_enabled, and token_metadata_enabled @@ -157,7 +204,7 @@ def test_schema_metadata_disable(self): no_schema = TestCluster(schema_metadata_enabled=False) no_schema_session = no_schema.connect() assert len(no_schema.metadata.keyspaces) == 0 - assert no_schema.metadata.export_schema_as_string() == '' + assert no_schema.metadata.export_schema_as_string() == "" no_token = TestCluster(token_metadata_enabled=False) no_token_session = no_token.connect() assert len(no_token.metadata.token_map.token_to_host_owner) == 0 @@ -171,18 +218,27 @@ def test_schema_metadata_disable(self): no_schema.shutdown() no_token.shutdown() - def make_create_statement(self, partition_cols, clustering_cols=None, other_cols=None): + def make_create_statement( + self, partition_cols, clustering_cols=None, other_cols=None + ): clustering_cols = clustering_cols or [] other_cols = other_cols or [] - statement = "CREATE TABLE %s.%s (" % (self.keyspace_name, self.function_table_name) + statement = "CREATE TABLE %s.%s (" % ( + self.keyspace_name, + self.function_table_name, + ) if len(partition_cols) == 1 and not clustering_cols: statement += "%s text PRIMARY KEY, " % protect_name(partition_cols[0]) else: - statement += ", ".join("%s text" % protect_name(col) for col in partition_cols) + statement += ", ".join( + "%s text" % protect_name(col) for col in partition_cols + ) statement += ", " - statement += ", ".join("%s text" % protect_name(col) for col in clustering_cols + other_cols) + statement += ", ".join( + "%s text" % protect_name(col) for col in clustering_cols + other_cols + ) if len(partition_cols) != 1 or clustering_cols: statement += ", PRIMARY KEY (" @@ -204,18 +260,28 @@ def make_create_statement(self, partition_cols, clustering_cols=None, other_cols def check_create_statement(self, tablemeta, original): recreate = tablemeta.as_cql_query(formatted=False) - assert original == recreate[:len(original)] - execute_until_pass(self.session, "DROP TABLE {0}.{1}".format(self.keyspace_name, self.function_table_name)) + assert original == recreate[: len(original)] + execute_until_pass( + self.session, + "DROP TABLE {0}.{1}".format(self.keyspace_name, self.function_table_name), + ) execute_until_pass(self.session, recreate) # create the table again, but with formatting enabled - execute_until_pass(self.session, "DROP TABLE {0}.{1}".format(self.keyspace_name, self.function_table_name)) + execute_until_pass( + self.session, + "DROP TABLE {0}.{1}".format(self.keyspace_name, self.function_table_name), + ) recreate = tablemeta.as_cql_query(formatted=True) execute_until_pass(self.session, recreate) def get_table_metadata(self): - self.cluster.refresh_table_metadata(self.keyspace_name, self.function_table_name) - return self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name] + self.cluster.refresh_table_metadata( + self.keyspace_name, self.function_table_name + ) + return self.cluster.metadata.keyspaces[self.keyspace_name].tables[ + self.function_table_name + ] def test_basic_table_meta_properties(self): create_statement = self.make_create_statement(["a"], [], ["b", "c"]) @@ -230,7 +296,7 @@ def test_basic_table_meta_properties(self): assert ksmeta.name == self.keyspace_name assert ksmeta.durable_writes - assert ksmeta.replication_strategy.name == 'SimpleStrategy' + assert ksmeta.replication_strategy.name == "SimpleStrategy" assert ksmeta.replication_strategy.replication_factor == 1 assert self.function_table_name in ksmeta.tables @@ -239,9 +305,9 @@ def test_basic_table_meta_properties(self): assert tablemeta.name == self.function_table_name assert tablemeta.name == self.function_table_name - assert [u'a'] == [c.name for c in tablemeta.partition_key] + assert ["a"] == [c.name for c in tablemeta.partition_key] assert [] == tablemeta.clustering_key - assert [u'a', u'b', u'c'] == sorted(tablemeta.columns.keys()) + assert ["a", "b", "c"] == sorted(tablemeta.columns.keys()) cc = self.cluster.control_connection._connection parser = get_schema_parser( @@ -263,9 +329,9 @@ def test_compound_primary_keys(self): self.session.execute(create_statement) tablemeta = self.get_table_metadata() - assert [u'a'] == [c.name for c in tablemeta.partition_key] - assert [u'b'] == [c.name for c in tablemeta.clustering_key] - assert [u'a', u'b', u'c'] == sorted(tablemeta.columns.keys()) + assert ["a"] == [c.name for c in tablemeta.partition_key] + assert ["b"] == [c.name for c in tablemeta.clustering_key] + assert ["a", "b", "c"] == sorted(tablemeta.columns.keys()) self.check_create_statement(tablemeta, create_statement) @@ -275,21 +341,23 @@ def test_compound_primary_keys_protected(self): self.session.execute(create_statement) tablemeta = self.get_table_metadata() - assert [u'Aa'] == [c.name for c in tablemeta.partition_key] - assert [u'Bb'] == [c.name for c in tablemeta.clustering_key] - assert [u'Aa', u'Bb', u'Cc'] == sorted(tablemeta.columns.keys()) + assert ["Aa"] == [c.name for c in tablemeta.partition_key] + assert ["Bb"] == [c.name for c in tablemeta.clustering_key] + assert ["Aa", "Bb", "Cc"] == sorted(tablemeta.columns.keys()) self.check_create_statement(tablemeta, create_statement) def test_compound_primary_keys_more_columns(self): - create_statement = self.make_create_statement(["a"], ["b", "c"], ["d", "e", "f"]) + create_statement = self.make_create_statement( + ["a"], ["b", "c"], ["d", "e", "f"] + ) create_statement += " WITH CLUSTERING ORDER BY (b ASC, c ASC)" self.session.execute(create_statement) tablemeta = self.get_table_metadata() - assert [u'a'] == [c.name for c in tablemeta.partition_key] - assert [u'b', u'c'] == [c.name for c in tablemeta.clustering_key] - assert [u'a', u'b', u'c', u'd', u'e', u'f'] == sorted(tablemeta.columns.keys()) + assert ["a"] == [c.name for c in tablemeta.partition_key] + assert ["b", "c"] == [c.name for c in tablemeta.clustering_key] + assert ["a", "b", "c", "d", "e", "f"] == sorted(tablemeta.columns.keys()) self.check_create_statement(tablemeta, create_statement) @@ -298,9 +366,9 @@ def test_composite_primary_key(self): self.session.execute(create_statement) tablemeta = self.get_table_metadata() - assert [u'a', u'b'] == [c.name for c in tablemeta.partition_key] + assert ["a", "b"] == [c.name for c in tablemeta.partition_key] assert [] == tablemeta.clustering_key - assert [u'a', u'b', u'c'] == sorted(tablemeta.columns.keys()) + assert ["a", "b", "c"] == sorted(tablemeta.columns.keys()) self.check_create_statement(tablemeta, create_statement) @@ -310,9 +378,9 @@ def test_composite_in_compound_primary_key(self): self.session.execute(create_statement) tablemeta = self.get_table_metadata() - assert [u'a', u'b'] == [c.name for c in tablemeta.partition_key] - assert [u'c'] == [c.name for c in tablemeta.clustering_key] - assert [u'a', u'b', u'c', u'd', u'e'] == sorted(tablemeta.columns.keys()) + assert ["a", "b"] == [c.name for c in tablemeta.partition_key] + assert ["c"] == [c.name for c in tablemeta.clustering_key] + assert ["a", "b", "c", "d", "e"] == sorted(tablemeta.columns.keys()) self.check_create_statement(tablemeta, create_statement) @@ -322,9 +390,9 @@ def test_compound_primary_keys_compact(self): self.session.execute(create_statement) tablemeta = self.get_table_metadata() - assert [u'a'] == [c.name for c in tablemeta.partition_key] - assert [u'b'] == [c.name for c in tablemeta.clustering_key] - assert [u'a', u'b', u'c'] == sorted(tablemeta.columns.keys()) + assert ["a"] == [c.name for c in tablemeta.partition_key] + assert ["b"] == [c.name for c in tablemeta.clustering_key] + assert ["a", "b", "c"] == sorted(tablemeta.columns.keys()) self.check_create_statement(tablemeta, create_statement) @@ -345,9 +413,9 @@ def test_cluster_column_ordering_reversed_metadata(self): create_statement += " WITH CLUSTERING ORDER BY (b ASC, c DESC)" self.session.execute(create_statement) tablemeta = self.get_table_metadata() - b_column = tablemeta.columns['b'] + b_column = tablemeta.columns["b"] assert not b_column.is_reversed - c_column = tablemeta.columns['c'] + c_column = tablemeta.columns["c"] assert c_column.is_reversed def test_compound_primary_keys_more_columns_compact(self): @@ -356,9 +424,9 @@ def test_compound_primary_keys_more_columns_compact(self): self.session.execute(create_statement) tablemeta = self.get_table_metadata() - assert [u'a'] == [c.name for c in tablemeta.partition_key] - assert [u'b', u'c'] == [c.name for c in tablemeta.clustering_key] - assert [u'a', u'b', u'c', u'd'] == sorted(tablemeta.columns.keys()) + assert ["a"] == [c.name for c in tablemeta.partition_key] + assert ["b", "c"] == [c.name for c in tablemeta.clustering_key] + assert ["a", "b", "c", "d"] == sorted(tablemeta.columns.keys()) self.check_create_statement(tablemeta, create_statement) @@ -367,9 +435,9 @@ def test_composite_primary_key_compact(self): self.session.execute(create_statement) tablemeta = self.get_table_metadata() - assert [u'a', u'b'] == [c.name for c in tablemeta.partition_key] + assert ["a", "b"] == [c.name for c in tablemeta.partition_key] assert [] == tablemeta.clustering_key - assert [u'a', u'b', u'c'] == sorted(tablemeta.columns.keys()) + assert ["a", "b", "c"] == sorted(tablemeta.columns.keys()) self.check_create_statement(tablemeta, create_statement) @@ -379,9 +447,9 @@ def test_composite_in_compound_primary_key_compact(self): self.session.execute(create_statement) tablemeta = self.get_table_metadata() - assert [u'a', u'b'] == [c.name for c in tablemeta.partition_key] - assert [u'c'] == [c.name for c in tablemeta.clustering_key] - assert [u'a', u'b', u'c', u'd'] == sorted(tablemeta.columns.keys()) + assert ["a", "b"] == [c.name for c in tablemeta.partition_key] + assert ["c"] == [c.name for c in tablemeta.clustering_key] + assert ["a", "b", "c", "d"] == sorted(tablemeta.columns.keys()) self.check_create_statement(tablemeta, create_statement) @@ -394,9 +462,9 @@ def test_cql_compatibility(self): self.session.execute(create_statement) tablemeta = self.get_table_metadata() - assert [u'a'] == [c.name for c in tablemeta.partition_key] + assert ["a"] == [c.name for c in tablemeta.partition_key] assert [] == tablemeta.clustering_key - assert [u'a', u'b', u'c', u'd'] == sorted(tablemeta.columns.keys()) + assert ["a", "b", "c", "d"] == sorted(tablemeta.columns.keys()) assert tablemeta.is_cql_compatible @@ -415,7 +483,9 @@ def test_compound_primary_keys_ordering(self): self.check_create_statement(tablemeta, create_statement) def test_compound_primary_keys_more_columns_ordering(self): - create_statement = self.make_create_statement(["a"], ["b", "c"], ["d", "e", "f"]) + create_statement = self.make_create_statement( + ["a"], ["b", "c"], ["d", "e", "f"] + ) create_statement += " WITH CLUSTERING ORDER BY (b DESC, c ASC)" self.session.execute(create_statement) tablemeta = self.get_table_metadata() @@ -450,8 +520,7 @@ def test_dense_compact_storage(self): def test_counter(self): create_statement = ( - "CREATE TABLE {keyspace}.{table} (" - "key text PRIMARY KEY, a1 counter)" + "CREATE TABLE {keyspace}.{table} (key text PRIMARY KEY, a1 counter)" ).format(keyspace=self.keyspace_name, table=self.function_table_name) self.session.execute(create_statement) @@ -461,7 +530,7 @@ def test_counter(self): @lessthancass40 @requirescompactstorage def test_counter_with_compact_storage(self): - """ PYTHON-1100 """ + """PYTHON-1100""" create_statement = ( "CREATE TABLE {keyspace}.{table} (" "key text PRIMARY KEY, a1 counter) WITH COMPACT STORAGE" @@ -483,20 +552,28 @@ def test_counter_with_dense_compact_storage(self): tablemeta = self.get_table_metadata() self.check_create_statement(tablemeta, create_statement) - @pytest.mark.skip(reason='https://github.com/scylladb/scylladb/issues/6058') + @pytest.mark.skip(reason="https://github.com/scylladb/scylladb/issues/6058") def test_indexes(self): - create_statement = self.make_create_statement(["a"], ["b", "c"], ["d", "e", "f"]) + create_statement = self.make_create_statement( + ["a"], ["b", "c"], ["d", "e", "f"] + ) create_statement += " WITH CLUSTERING ORDER BY (b ASC, c ASC)" execute_until_pass(self.session, create_statement) - d_index = "CREATE INDEX d_index ON %s.%s (d)" % (self.keyspace_name, self.function_table_name) - e_index = "CREATE INDEX e_index ON %s.%s (e)" % (self.keyspace_name, self.function_table_name) + d_index = "CREATE INDEX d_index ON %s.%s (d)" % ( + self.keyspace_name, + self.function_table_name, + ) + e_index = "CREATE INDEX e_index ON %s.%s (e)" % ( + self.keyspace_name, + self.function_table_name, + ) execute_until_pass(self.session, d_index) execute_until_pass(self.session, e_index) tablemeta = self.get_table_metadata() statements = tablemeta.export_as_string().strip() - statements = [s.strip() for s in statements.split(';')] + statements = [s.strip() for s in statements.split(";")] statements = list(filter(bool, statements)) assert 3 == len(statements) assert d_index in statements @@ -505,40 +582,56 @@ def test_indexes(self): # make sure indexes are included in KeyspaceMetadata.export_as_string() ksmeta = self.cluster.metadata.keyspaces[self.keyspace_name] statement = ksmeta.export_as_string() - assert 'CREATE INDEX d_index' in statement - assert 'CREATE INDEX e_index' in statement + assert "CREATE INDEX d_index" in statement + assert "CREATE INDEX e_index" in statement @greaterthancass21 @requires_collection_indexes - @xfail_scylla('scylladb/scylladb#22013 - scylla does not show full index in system_schema.indexes') + @xfail_scylla( + "scylladb/scylladb#22013 - scylla does not show full index in system_schema.indexes" + ) def test_collection_indexes(self): - self.session.execute("CREATE TABLE %s.%s (a int PRIMARY KEY, b map)" - % (self.keyspace_name, self.function_table_name)) - self.session.execute("CREATE INDEX index1 ON %s.%s (keys(b))" - % (self.keyspace_name, self.function_table_name)) + self.session.execute( + "CREATE TABLE %s.%s (a int PRIMARY KEY, b map)" + % (self.keyspace_name, self.function_table_name) + ) + self.session.execute( + "CREATE INDEX index1 ON %s.%s (keys(b))" + % (self.keyspace_name, self.function_table_name) + ) tablemeta = self.get_table_metadata() - assert '(keys(b))' in tablemeta.export_as_string() + assert "(keys(b))" in tablemeta.export_as_string() self.session.execute("DROP INDEX %s.index1" % (self.keyspace_name,)) - self.session.execute("CREATE INDEX index2 ON %s.%s (b)" - % (self.keyspace_name, self.function_table_name)) + self.session.execute( + "CREATE INDEX index2 ON %s.%s (b)" + % (self.keyspace_name, self.function_table_name) + ) tablemeta = self.get_table_metadata() - target = ' (b)' if CASSANDRA_VERSION < Version("3.0") else 'values(b))' # explicit values in C* 3+ + target = ( + " (b)" if CASSANDRA_VERSION < Version("3.0") else "values(b))" + ) # explicit values in C* 3+ assert target in tablemeta.export_as_string() # test full indexes on frozen collections, if available if CASSANDRA_VERSION >= Version("2.1.3"): - self.session.execute("DROP TABLE %s.%s" % (self.keyspace_name, self.function_table_name)) - self.session.execute("CREATE TABLE %s.%s (a int PRIMARY KEY, b frozen>)" - % (self.keyspace_name, self.function_table_name)) - self.session.execute("CREATE INDEX index3 ON %s.%s (full(b))" - % (self.keyspace_name, self.function_table_name)) + self.session.execute( + "DROP TABLE %s.%s" % (self.keyspace_name, self.function_table_name) + ) + self.session.execute( + "CREATE TABLE %s.%s (a int PRIMARY KEY, b frozen>)" + % (self.keyspace_name, self.function_table_name) + ) + self.session.execute( + "CREATE INDEX index3 ON %s.%s (full(b))" + % (self.keyspace_name, self.function_table_name) + ) tablemeta = self.get_table_metadata() - assert '(full(b))' in tablemeta.export_as_string() + assert "(full(b))" in tablemeta.export_as_string() def test_compression_disabled(self): create_statement = self.make_create_statement(["a"], ["b"], ["c"]) @@ -572,7 +665,7 @@ def test_non_size_tiered_compaction(self): assert "'tombstone_threshold': '0.3'" in cql assert "LeveledCompactionStrategy" in cql # formerly legacy options; reintroduced in 4.0 - if CASSANDRA_VERSION < Version('4.0-a'): + if CASSANDRA_VERSION < Version("4.0-a"): assert "min_threshold" not in cql assert "max_threshold" not in cql @@ -601,56 +694,89 @@ def test_refresh_schema_metadata(self): assert "new_keyspace" not in cluster2.metadata.keyspaces # Cluster metadata modification - self.session.execute("CREATE KEYSPACE new_keyspace WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}") + self.session.execute( + "CREATE KEYSPACE new_keyspace WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}" + ) assert "new_keyspace" not in cluster2.metadata.keyspaces cluster2.refresh_schema_metadata() assert "new_keyspace" in cluster2.metadata.keyspaces # Keyspace metadata modification - self.session.execute("ALTER KEYSPACE {0} WITH durable_writes = false".format(self.keyspace_name)) + self.session.execute( + "ALTER KEYSPACE {0} WITH durable_writes = false".format(self.keyspace_name) + ) assert cluster2.metadata.keyspaces[self.keyspace_name].durable_writes cluster2.refresh_schema_metadata() assert not cluster2.metadata.keyspaces[self.keyspace_name].durable_writes # Table metadata modification table_name = "test" - self.session.execute("CREATE TABLE {0}.{1} (a int PRIMARY KEY, b text)".format(self.keyspace_name, table_name)) + self.session.execute( + "CREATE TABLE {0}.{1} (a int PRIMARY KEY, b text)".format( + self.keyspace_name, table_name + ) + ) cluster2.refresh_schema_metadata() - self.session.execute("ALTER TABLE {0}.{1} ADD c double".format(self.keyspace_name, table_name)) - assert "c" not in cluster2.metadata.keyspaces[self.keyspace_name].tables[table_name].columns + self.session.execute( + "ALTER TABLE {0}.{1} ADD c double".format(self.keyspace_name, table_name) + ) + assert ( + "c" + not in cluster2.metadata.keyspaces[self.keyspace_name] + .tables[table_name] + .columns + ) cluster2.refresh_schema_metadata() - assert "c" in cluster2.metadata.keyspaces[self.keyspace_name].tables[table_name].columns + assert ( + "c" + in cluster2.metadata.keyspaces[self.keyspace_name] + .tables[table_name] + .columns + ) if PROTOCOL_VERSION >= 3: # UDT metadata modification - self.session.execute("CREATE TYPE {0}.user (age int, name text)".format(self.keyspace_name)) + self.session.execute( + "CREATE TYPE {0}.user (age int, name text)".format(self.keyspace_name) + ) assert cluster2.metadata.keyspaces[self.keyspace_name].user_types == {} cluster2.refresh_schema_metadata() assert "user" in cluster2.metadata.keyspaces[self.keyspace_name].user_types if PROTOCOL_VERSION >= 4: # UDF metadata modification - self.session.execute("""CREATE FUNCTION {0}.sum_int(key int, val int) + self.session.execute( + """CREATE FUNCTION {0}.sum_int(key int, val int) RETURNS NULL ON NULL INPUT RETURNS int - LANGUAGE java AS 'return key+val;';""".format(self.keyspace_name)) + LANGUAGE java AS 'return key+val;';""".format( + self.keyspace_name + ) + ) assert cluster2.metadata.keyspaces[self.keyspace_name].functions == {} cluster2.refresh_schema_metadata() - assert "sum_int(int,int)" in cluster2.metadata.keyspaces[self.keyspace_name].functions + assert ( + "sum_int(int,int)" + in cluster2.metadata.keyspaces[self.keyspace_name].functions + ) # UDA metadata modification - self.session.execute("""CREATE AGGREGATE {0}.sum_agg(int) + self.session.execute( + """CREATE AGGREGATE {0}.sum_agg(int) SFUNC sum_int STYPE int - INITCOND 0""" - .format(self.keyspace_name)) + INITCOND 0""".format(self.keyspace_name) + ) assert cluster2.metadata.keyspaces[self.keyspace_name].aggregates == {} cluster2.refresh_schema_metadata() - assert "sum_agg(int)" in cluster2.metadata.keyspaces[self.keyspace_name].aggregates + assert ( + "sum_agg(int)" + in cluster2.metadata.keyspaces[self.keyspace_name].aggregates + ) # Cluster metadata modification self.session.execute("DROP KEYSPACE new_keyspace") @@ -682,7 +808,9 @@ def test_refresh_keyspace_metadata(self): cluster2.connect() assert cluster2.metadata.keyspaces[self.keyspace_name].durable_writes - self.session.execute("ALTER KEYSPACE {0} WITH durable_writes = false".format(self.keyspace_name)) + self.session.execute( + "ALTER KEYSPACE {0} WITH durable_writes = false".format(self.keyspace_name) + ) assert cluster2.metadata.keyspaces[self.keyspace_name].durable_writes cluster2.refresh_keyspace_metadata(self.keyspace_name) assert not cluster2.metadata.keyspaces[self.keyspace_name].durable_writes @@ -707,17 +835,38 @@ def test_refresh_table_metadata(self): """ table_name = "test" - self.session.execute("CREATE TABLE {0}.{1} (a int PRIMARY KEY, b text)".format(self.keyspace_name, table_name)) + self.session.execute( + "CREATE TABLE {0}.{1} (a int PRIMARY KEY, b text)".format( + self.keyspace_name, table_name + ) + ) cluster2 = TestCluster(schema_event_refresh_window=-1) cluster2.connect() - assert "c" not in cluster2.metadata.keyspaces[self.keyspace_name].tables[table_name].columns - self.session.execute("ALTER TABLE {0}.{1} ADD c double".format(self.keyspace_name, table_name)) - assert "c" not in cluster2.metadata.keyspaces[self.keyspace_name].tables[table_name].columns + assert ( + "c" + not in cluster2.metadata.keyspaces[self.keyspace_name] + .tables[table_name] + .columns + ) + self.session.execute( + "ALTER TABLE {0}.{1} ADD c double".format(self.keyspace_name, table_name) + ) + assert ( + "c" + not in cluster2.metadata.keyspaces[self.keyspace_name] + .tables[table_name] + .columns + ) cluster2.refresh_table_metadata(self.keyspace_name, table_name) - assert "c" in cluster2.metadata.keyspaces[self.keyspace_name].tables[table_name].columns + assert ( + "c" + in cluster2.metadata.keyspaces[self.keyspace_name] + .tables[table_name] + .columns + ) cluster2.shutdown() @@ -741,44 +890,92 @@ def test_refresh_metadata_for_mv(self): @test_category metadata """ - self.session.execute("CREATE TABLE {0}.{1} (a int PRIMARY KEY, b text)".format(self.keyspace_name, self.function_table_name)) + self.session.execute( + "CREATE TABLE {0}.{1} (a int PRIMARY KEY, b text)".format( + self.keyspace_name, self.function_table_name + ) + ) cluster2 = TestCluster(schema_event_refresh_window=-1) cluster2.connect() try: - assert "mv1" not in cluster2.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views - self.session.execute("CREATE MATERIALIZED VIEW {0}.mv1 AS SELECT a, b FROM {0}.{1} " - "WHERE a IS NOT NULL AND b IS NOT NULL PRIMARY KEY (a, b)" - .format(self.keyspace_name, self.function_table_name)) - assert "mv1" not in cluster2.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views + assert ( + "mv1" + not in cluster2.metadata.keyspaces[self.keyspace_name] + .tables[self.function_table_name] + .views + ) + self.session.execute( + "CREATE MATERIALIZED VIEW {0}.mv1 AS SELECT a, b FROM {0}.{1} " + "WHERE a IS NOT NULL AND b IS NOT NULL PRIMARY KEY (a, b)".format( + self.keyspace_name, self.function_table_name + ) + ) + assert ( + "mv1" + not in cluster2.metadata.keyspaces[self.keyspace_name] + .tables[self.function_table_name] + .views + ) cluster2.refresh_table_metadata(self.keyspace_name, "mv1") - assert "mv1" in cluster2.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views + assert ( + "mv1" + in cluster2.metadata.keyspaces[self.keyspace_name] + .tables[self.function_table_name] + .views + ) finally: cluster2.shutdown() - original_meta = self.cluster.metadata.keyspaces[self.keyspace_name].views['mv1'] - assert original_meta is self.session.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views['mv1'] - self.cluster.refresh_materialized_view_metadata(self.keyspace_name, 'mv1') + original_meta = self.cluster.metadata.keyspaces[self.keyspace_name].views["mv1"] + assert ( + original_meta + is self.session.cluster.metadata.keyspaces[self.keyspace_name] + .tables[self.function_table_name] + .views["mv1"] + ) + self.cluster.refresh_materialized_view_metadata(self.keyspace_name, "mv1") - current_meta = self.cluster.metadata.keyspaces[self.keyspace_name].views['mv1'] + current_meta = self.cluster.metadata.keyspaces[self.keyspace_name].views["mv1"] assert current_meta is not original_meta - assert original_meta is not self.session.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views['mv1'] + assert ( + original_meta + is not self.session.cluster.metadata.keyspaces[self.keyspace_name] + .tables[self.function_table_name] + .views["mv1"] + ) assert original_meta.as_cql_query() == current_meta.as_cql_query() cluster3 = TestCluster(schema_event_refresh_window=-1) cluster3.connect() try: - assert "mv2" not in cluster3.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views + assert ( + "mv2" + not in cluster3.metadata.keyspaces[self.keyspace_name] + .tables[self.function_table_name] + .views + ) self.session.execute( "CREATE MATERIALIZED VIEW {0}.mv2 AS SELECT a, b FROM {0}.{1} " "WHERE a IS NOT NULL AND b IS NOT NULL PRIMARY KEY (a, b)".format( - self.keyspace_name, self.function_table_name) + self.keyspace_name, self.function_table_name + ) + ) + assert ( + "mv2" + not in cluster3.metadata.keyspaces[self.keyspace_name] + .tables[self.function_table_name] + .views + ) + cluster3.refresh_materialized_view_metadata(self.keyspace_name, "mv2") + assert ( + "mv2" + in cluster3.metadata.keyspaces[self.keyspace_name] + .tables[self.function_table_name] + .views ) - assert "mv2" not in cluster3.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views - cluster3.refresh_materialized_view_metadata(self.keyspace_name, 'mv2') - assert "mv2" in cluster3.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views finally: cluster3.shutdown() @@ -800,13 +997,19 @@ def test_refresh_user_type_metadata(self): """ if PROTOCOL_VERSION < 3: - raise unittest.SkipTest("Protocol 3+ is required for UDTs, currently testing against {0}".format(PROTOCOL_VERSION)) + raise unittest.SkipTest( + "Protocol 3+ is required for UDTs, currently testing against {0}".format( + PROTOCOL_VERSION + ) + ) cluster2 = TestCluster(schema_event_refresh_window=-1) cluster2.connect() assert cluster2.metadata.keyspaces[self.keyspace_name].user_types == {} - self.session.execute("CREATE TYPE {0}.user (age int, name text)".format(self.keyspace_name)) + self.session.execute( + "CREATE TYPE {0}.user (age int, name text)".format(self.keyspace_name) + ) assert cluster2.metadata.keyspaces[self.keyspace_name].user_types == {} cluster2.refresh_user_type_metadata(self.keyspace_name, "user") @@ -827,23 +1030,55 @@ def test_refresh_user_type_metadata_proto_2(self): """ supported_versions = get_supported_protocol_versions() if 2 not in supported_versions: # 1 and 2 were dropped in the same version - raise unittest.SkipTest("Protocol versions 1 and 2 are not supported in Cassandra version ".format(CASSANDRA_VERSION)) + raise unittest.SkipTest( + "Protocol versions 1 and 2 are not supported in Cassandra version ".format( + CASSANDRA_VERSION + ) + ) for protocol_version in (1, 2): cluster = TestCluster() session = cluster.connect() assert cluster.metadata.keyspaces[self.keyspace_name].user_types == {} - session.execute("CREATE TYPE {0}.user (age int, name text)".format(self.keyspace_name)) + session.execute( + "CREATE TYPE {0}.user (age int, name text)".format(self.keyspace_name) + ) assert "user" in cluster.metadata.keyspaces[self.keyspace_name].user_types - assert "age" in cluster.metadata.keyspaces[self.keyspace_name].user_types["user"].field_names - assert "name" in cluster.metadata.keyspaces[self.keyspace_name].user_types["user"].field_names + assert ( + "age" + in cluster.metadata.keyspaces[self.keyspace_name] + .user_types["user"] + .field_names + ) + assert ( + "name" + in cluster.metadata.keyspaces[self.keyspace_name] + .user_types["user"] + .field_names + ) - session.execute("ALTER TYPE {0}.user ADD flag boolean".format(self.keyspace_name)) - assert "flag" in cluster.metadata.keyspaces[self.keyspace_name].user_types["user"].field_names + session.execute( + "ALTER TYPE {0}.user ADD flag boolean".format(self.keyspace_name) + ) + assert ( + "flag" + in cluster.metadata.keyspaces[self.keyspace_name] + .user_types["user"] + .field_names + ) - session.execute("ALTER TYPE {0}.user RENAME flag TO something".format(self.keyspace_name)) - assert "something" in cluster.metadata.keyspaces[self.keyspace_name].user_types["user"].field_names + session.execute( + "ALTER TYPE {0}.user RENAME flag TO something".format( + self.keyspace_name + ) + ) + assert ( + "something" + in cluster.metadata.keyspaces[self.keyspace_name] + .user_types["user"] + .field_names + ) session.execute("DROP TYPE {0}.user".format(self.keyspace_name)) assert cluster.metadata.keyspaces[self.keyspace_name].user_types == {} @@ -869,20 +1104,33 @@ def test_refresh_user_function_metadata(self): """ if PROTOCOL_VERSION < 4: - raise unittest.SkipTest("Protocol 4+ is required for UDFs, currently testing against {0}".format(PROTOCOL_VERSION)) + raise unittest.SkipTest( + "Protocol 4+ is required for UDFs, currently testing against {0}".format( + PROTOCOL_VERSION + ) + ) cluster2 = TestCluster(schema_event_refresh_window=-1) cluster2.connect() assert cluster2.metadata.keyspaces[self.keyspace_name].functions == {} - self.session.execute("""CREATE FUNCTION {0}.sum_int(key int, val int) + self.session.execute( + """CREATE FUNCTION {0}.sum_int(key int, val int) RETURNS NULL ON NULL INPUT RETURNS int - LANGUAGE java AS ' return key + val;';""".format(self.keyspace_name)) + LANGUAGE java AS ' return key + val;';""".format( + self.keyspace_name + ) + ) assert cluster2.metadata.keyspaces[self.keyspace_name].functions == {} - cluster2.refresh_user_function_metadata(self.keyspace_name, UserFunctionDescriptor("sum_int", ["int", "int"])) - assert "sum_int(int,int)" in cluster2.metadata.keyspaces[self.keyspace_name].functions + cluster2.refresh_user_function_metadata( + self.keyspace_name, UserFunctionDescriptor("sum_int", ["int", "int"]) + ) + assert ( + "sum_int(int,int)" + in cluster2.metadata.keyspaces[self.keyspace_name].functions + ) cluster2.shutdown() @@ -906,26 +1154,39 @@ def test_refresh_user_aggregate_metadata(self): """ if PROTOCOL_VERSION < 4: - raise unittest.SkipTest("Protocol 4+ is required for UDAs, currently testing against {0}".format(PROTOCOL_VERSION)) + raise unittest.SkipTest( + "Protocol 4+ is required for UDAs, currently testing against {0}".format( + PROTOCOL_VERSION + ) + ) cluster2 = TestCluster(schema_event_refresh_window=-1) cluster2.connect() assert cluster2.metadata.keyspaces[self.keyspace_name].aggregates == {} - self.session.execute("""CREATE FUNCTION {0}.sum_int(key int, val int) + self.session.execute( + """CREATE FUNCTION {0}.sum_int(key int, val int) RETURNS NULL ON NULL INPUT RETURNS int - LANGUAGE java AS 'return key + val;';""".format(self.keyspace_name)) + LANGUAGE java AS 'return key + val;';""".format( + self.keyspace_name + ) + ) - self.session.execute("""CREATE AGGREGATE {0}.sum_agg(int) + self.session.execute( + """CREATE AGGREGATE {0}.sum_agg(int) SFUNC sum_int STYPE int - INITCOND 0""" - .format(self.keyspace_name)) + INITCOND 0""".format(self.keyspace_name) + ) assert cluster2.metadata.keyspaces[self.keyspace_name].aggregates == {} - cluster2.refresh_user_aggregate_metadata(self.keyspace_name, UserAggregateDescriptor("sum_agg", ["int"])) - assert "sum_agg(int)" in cluster2.metadata.keyspaces[self.keyspace_name].aggregates + cluster2.refresh_user_aggregate_metadata( + self.keyspace_name, UserAggregateDescriptor("sum_agg", ["int"]) + ) + assert ( + "sum_agg(int)" in cluster2.metadata.keyspaces[self.keyspace_name].aggregates + ) cluster2.shutdown() @@ -944,14 +1205,30 @@ def test_multiple_indices(self): @test_category metadata """ - self.session.execute("CREATE TABLE {0}.{1} (a int PRIMARY KEY, b map)".format(self.keyspace_name, self.function_table_name)) - self.session.execute("CREATE INDEX index_1 ON {0}.{1}(b)".format(self.keyspace_name, self.function_table_name)) - self.session.execute("CREATE INDEX index_2 ON {0}.{1}(keys(b))".format(self.keyspace_name, self.function_table_name)) + self.session.execute( + "CREATE TABLE {0}.{1} (a int PRIMARY KEY, b map)".format( + self.keyspace_name, self.function_table_name + ) + ) + self.session.execute( + "CREATE INDEX index_1 ON {0}.{1}(b)".format( + self.keyspace_name, self.function_table_name + ) + ) + self.session.execute( + "CREATE INDEX index_2 ON {0}.{1}(keys(b))".format( + self.keyspace_name, self.function_table_name + ) + ) - indices = self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].indexes + indices = ( + self.cluster.metadata.keyspaces[self.keyspace_name] + .tables[self.function_table_name] + .indexes + ) assert len(indices) == 2 index_1 = indices["index_1"] - index_2 = indices['index_2'] + index_2 = indices["index_2"] assert index_1.table_name == "test_multiple_indices" assert index_1.name == "index_1" assert index_1.kind == "COMPOSITES" @@ -969,7 +1246,7 @@ def test_table_extensions(self): ks = self.keyspace_name ks_meta = s.cluster.metadata.keyspaces[ks] t = self.function_table_name - v = t + 'view' + v = t + "view" s.execute("CREATE TABLE %s.%s (k text PRIMARY KEY, v int)" % (ks, t)) s.execute( @@ -993,7 +1270,7 @@ def after_table_cql(cls, table_meta, ext_key, ext_blob): return "%s %s %s %s" % (cls.name, table_meta.name, ext_key, ext_blob) class Ext1(Ext0): - name = t + '##' + name = t + "##" assert Ext0.name in _RegisteredExtensionType._extension_registry assert Ext1.name in _RegisteredExtensionType._extension_registry @@ -1007,13 +1284,22 @@ class Ext1(Ext0): assert table_meta.export_as_string() == original_table_cql assert view_meta.export_as_string() == original_view_cql - update_t = s.prepare('UPDATE system_schema.tables SET extensions=? WHERE keyspace_name=? AND table_name=?') # for blob type coercing - update_v = s.prepare('UPDATE system_schema.views SET extensions=? WHERE keyspace_name=? AND view_name=?') + update_t = s.prepare( + "UPDATE system_schema.tables SET extensions=? WHERE keyspace_name=? AND table_name=?" + ) # for blob type coercing + update_v = s.prepare( + "UPDATE system_schema.views SET extensions=? WHERE keyspace_name=? AND view_name=?" + ) # extensions registered, one present # -------------------------------------- ext_map = {Ext0.name: b"THA VALUE"} - [(s.execute(update_t, (ext_map, ks, t)), s.execute(update_v, (ext_map, ks, v))) - for _ in self.cluster.metadata.all_hosts()] # we're manipulating metadata - do it on all hosts + [ + ( + s.execute(update_t, (ext_map, ks, t)), + s.execute(update_v, (ext_map, ks, v)), + ) + for _ in self.cluster.metadata.all_hosts() + ] # we're manipulating metadata - do it on all hosts self.cluster.refresh_table_metadata(ks, t) self.cluster.refresh_materialized_view_metadata(ks, v) table_meta = ks_meta.tables[t] @@ -1022,7 +1308,9 @@ class Ext1(Ext0): assert Ext0.name in table_meta.extensions new_cql = table_meta.export_as_string() assert new_cql != original_table_cql - assert Ext0.after_table_cql(table_meta, Ext0.name, ext_map[Ext0.name]) in new_cql + assert ( + Ext0.after_table_cql(table_meta, Ext0.name, ext_map[Ext0.name]) in new_cql + ) assert Ext1.name not in new_cql assert Ext0.name in view_meta.extensions @@ -1033,10 +1321,14 @@ class Ext1(Ext0): # extensions registered, one present # -------------------------------------- - ext_map = {Ext0.name: b"THA VALUE", - Ext1.name: b"OTHA VALUE"} - [(s.execute(update_t, (ext_map, ks, t)), s.execute(update_v, (ext_map, ks, v))) - for _ in self.cluster.metadata.all_hosts()] # we're manipulating metadata - do it on all hosts + ext_map = {Ext0.name: b"THA VALUE", Ext1.name: b"OTHA VALUE"} + [ + ( + s.execute(update_t, (ext_map, ks, t)), + s.execute(update_v, (ext_map, ks, v)), + ) + for _ in self.cluster.metadata.all_hosts() + ] # we're manipulating metadata - do it on all hosts self.cluster.refresh_table_metadata(ks, t) self.cluster.refresh_materialized_view_metadata(ks, v) table_meta = ks_meta.tables[t] @@ -1046,8 +1338,12 @@ class Ext1(Ext0): assert Ext1.name in table_meta.extensions new_cql = table_meta.export_as_string() assert new_cql != original_table_cql - assert Ext0.after_table_cql(table_meta, Ext0.name, ext_map[Ext0.name]) in new_cql - assert Ext1.after_table_cql(table_meta, Ext1.name, ext_map[Ext1.name]) in new_cql + assert ( + Ext0.after_table_cql(table_meta, Ext0.name, ext_map[Ext0.name]) in new_cql + ) + assert ( + Ext1.after_table_cql(table_meta, Ext1.name, ext_map[Ext1.name]) in new_cql + ) assert Ext0.name in view_meta.extensions assert Ext1.name in view_meta.extensions @@ -1059,8 +1355,10 @@ class Ext1(Ext0): def test_metadata_pagination(self): self.cluster.refresh_schema_metadata() for i in range(12): - self.session.execute("CREATE TABLE %s.%s_%d (a int PRIMARY KEY, b map)" - % (self.keyspace_name, self.function_table_name, i)) + self.session.execute( + "CREATE TABLE %s.%s_%d (a int PRIMARY KEY, b map)" + % (self.keyspace_name, self.function_table_name, i) + ) self.cluster.schema_metadata_page_size = 5 self.cluster.refresh_schema_metadata() @@ -1093,7 +1391,6 @@ def test_metadata_pagination_keyspaces(self): class TestCodeCoverage(unittest.TestCase): - def test_export_schema(self): """ Test export schema functionality @@ -1119,85 +1416,12 @@ def test_export_keyspace_schema(self): assert isinstance(keyspace_metadata.as_cql_query(), str) cluster.shutdown() - @greaterthancass20 - def test_export_keyspace_schema_udts(self): - """ - Test udt exports - """ - - if PROTOCOL_VERSION < 3: - raise unittest.SkipTest( - "Protocol 3.0+ is required for UDT change events, currently testing against %r" - % (PROTOCOL_VERSION,)) - - if sys.version_info[0:2] != (2, 7): - raise unittest.SkipTest('This test compares static strings generated from dict items, which may change orders. Test with 2.7.') - - cluster = TestCluster() - session = cluster.connect() - - session.execute(""" - CREATE KEYSPACE export_udts - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} - AND durable_writes = true; - """) - session.execute(""" - CREATE TYPE export_udts.street ( - street_number int, - street_name text) - """) - session.execute(""" - CREATE TYPE export_udts.zip ( - zipcode int, - zip_plus_4 int) - """) - session.execute(""" - CREATE TYPE export_udts.address ( - street_address frozen, - zip_code frozen) - """) - session.execute(""" - CREATE TABLE export_udts.users ( - user text PRIMARY KEY, - addresses map>) - """) - - expected_prefix = """CREATE KEYSPACE export_udts WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} AND durable_writes = true; - -CREATE TYPE export_udts.street ( - street_number int, - street_name text -); - -CREATE TYPE export_udts.zip ( - zipcode int, - zip_plus_4 int -); - -CREATE TYPE export_udts.address ( - street_address frozen, - zip_code frozen -); - -CREATE TABLE export_udts.users ( - user text PRIMARY KEY, - addresses map>""" - - assert_startswith_diff(cluster.metadata.keyspaces['export_udts'].export_as_string(), expected_prefix) - - table_meta = cluster.metadata.keyspaces['export_udts'].tables['users'] - - expected_prefix = """CREATE TABLE export_udts.users ( - user text PRIMARY KEY, - addresses map>""" - - assert_startswith_diff(table_meta.export_as_string(), expected_prefix) - - cluster.shutdown() - @greaterthancass21 - @xfail_scylla_version_lt(reason='scylladb/scylladb#10707 - Column name in CREATE INDEX is not quoted', - oss_scylla_version="5.2", ent_scylla_version="2023.1.1") + @xfail_scylla_version_lt( + reason="scylladb/scylladb#10707 - Column name in CREATE INDEX is not quoted", + oss_scylla_version="5.2", + ent_scylla_version="2023.1.1", + ) def test_case_sensitivity(self): """ Test that names that need to be escaped in CREATE statements are @@ -1206,15 +1430,19 @@ def test_case_sensitivity(self): cluster = TestCluster() session = cluster.connect() - ksname = 'AnInterestingKeyspace' - cfname = 'AnInterestingTable' + ksname = "AnInterestingKeyspace" + cfname = "AnInterestingTable" session.execute("DROP KEYSPACE IF EXISTS {0}".format(ksname)) - session.execute(""" + session.execute( + """ CREATE KEYSPACE "%s" WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} - """ % (ksname,)) - session.execute(""" + """ + % (ksname,) + ) + session.execute( + """ CREATE TABLE "%s"."%s" ( k int, "A" int, @@ -1222,13 +1450,21 @@ def test_case_sensitivity(self): "MyColumn" int, PRIMARY KEY (k, "A")) WITH CLUSTERING ORDER BY ("A" DESC) - """ % (ksname, cfname)) - session.execute(""" + """ + % (ksname, cfname) + ) + session.execute( + """ CREATE INDEX myindex ON "%s"."%s" ("MyColumn") - """ % (ksname, cfname)) - session.execute(""" + """ + % (ksname, cfname) + ) + session.execute( + """ CREATE INDEX "AnotherIndex" ON "%s"."%s" ("B") - """ % (ksname, cfname)) + """ + % (ksname, cfname) + ) ksmeta = cluster.metadata.keyspaces[ksname] schema = ksmeta.export_as_string() @@ -1239,8 +1475,14 @@ def test_case_sensitivity(self): assert '"MyColumn" int' in schema assert 'PRIMARY KEY (k, "A")' in schema assert 'WITH CLUSTERING ORDER BY ("A" DESC)' in schema - assert 'CREATE INDEX myindex ON "AnInterestingKeyspace"."AnInterestingTable" ("MyColumn")' in schema - assert 'CREATE INDEX "AnotherIndex" ON "AnInterestingKeyspace"."AnInterestingTable" ("B")' in schema + assert ( + 'CREATE INDEX myindex ON "AnInterestingKeyspace"."AnInterestingTable" ("MyColumn")' + in schema + ) + assert ( + 'CREATE INDEX "AnotherIndex" ON "AnInterestingKeyspace"."AnInterestingTable" ("B")' + in schema + ) cluster.shutdown() def test_already_exists_exceptions(self): @@ -1251,41 +1493,43 @@ def test_already_exists_exceptions(self): cluster = TestCluster() session = cluster.connect() - ksname = 'test3rf' - cfname = 'test' + ksname = "test3rf" + cfname = "test" - ddl = ''' + ddl = """ CREATE KEYSPACE %s - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'}''' + WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'}""" with pytest.raises(AlreadyExists): session.execute(ddl % ksname) - ddl = ''' + ddl = """ CREATE TABLE %s.%s ( k int PRIMARY KEY, - v int )''' + v int )""" with pytest.raises(AlreadyExists): session.execute(ddl % (ksname, cfname)) cluster.shutdown() @local - @pytest.mark.xfail(reason='AssertionError: \'RAC1\' != \'r1\' - probably a bug in driver or in Scylla') + @pytest.mark.xfail( + reason="AssertionError: 'RAC1' != 'r1' - probably a bug in driver or in Scylla" + ) def test_replicas(self): """ Ensure cluster.metadata.get_replicas return correctly when not attached to keyspace """ if murmur3 is None: - raise unittest.SkipTest('the murmur3 extension is not available') + raise unittest.SkipTest("the murmur3 extension is not available") cluster = TestCluster() - assert cluster.metadata.get_replicas('test3rf', 'key') == [] + assert cluster.metadata.get_replicas("test3rf", "key") == [] - cluster.connect('test3rf') + cluster.connect("test3rf") - assert list(cluster.metadata.get_replicas('test3rf', b'key')) != [] - host = list(cluster.metadata.get_replicas('test3rf', b'key'))[0] - assert host.datacenter == 'dc1' - assert host.rack == 'r1' + assert list(cluster.metadata.get_replicas("test3rf", b"key")) != [] + host = list(cluster.metadata.get_replicas("test3rf", b"key"))[0] + assert host.datacenter == "dc1" + assert host.rack == "r1" cluster.shutdown() def test_token_map(self): @@ -1294,18 +1538,22 @@ def test_token_map(self): """ cluster = TestCluster() - cluster.connect('test3rf') + cluster.connect("test3rf") ring = cluster.metadata.token_map.ring - owners = list(cluster.metadata.token_map.token_to_host_owner[token] for token in ring) + owners = list( + cluster.metadata.token_map.token_to_host_owner[token] for token in ring + ) get_replicas = cluster.metadata.token_map.get_replicas - for ksname in ('test1rf', 'test2rf', 'test3rf'): + for ksname in ("test1rf", "test2rf", "test3rf"): assert list(get_replicas(ksname, ring[0])) != [] for i, token in enumerate(ring): - assert set(get_replicas('test3rf', token)) == set(owners) - assert set(get_replicas('test2rf', token)) == set([owners[i], owners[(i + 1) % 3]]) - assert set(get_replicas('test1rf', token)) == set([owners[i]]) + assert set(get_replicas("test3rf", token)) == set(owners) + assert set(get_replicas("test2rf", token)) == set( + [owners[i], owners[(i + 1) % 3]] + ) + assert set(get_replicas("test1rf", token)) == set([owners[i]]) cluster.shutdown() @@ -1313,6 +1561,7 @@ class TokenMetadataTest(unittest.TestCase): """ Test of TokenMap creation and other behavior. """ + @local def test_token(self): expected_node_count = len(get_cluster().nodes) @@ -1334,35 +1583,38 @@ class TestMetadataTimeout: "opts, expected_query_chunk", [ ( - {"metadata_request_timeout": None}, - # Should be borrowed from control_connection_timeout - "USING TIMEOUT 2000ms" - ), - ( - {"metadata_request_timeout": 0.0}, - False + {"metadata_request_timeout": None}, + # Should be borrowed from control_connection_timeout + "USING TIMEOUT 2000ms", ), + ({"metadata_request_timeout": 0.0}, False), + ({"metadata_request_timeout": 4.0}, "USING TIMEOUT 4000ms"), ( - {"metadata_request_timeout": 4.0}, - "USING TIMEOUT 4000ms" + {"metadata_request_timeout": None, "control_connection_timeout": None}, + False, ), - ( - {"metadata_request_timeout": None, "control_connection_timeout": None}, - False, - ) ], - ids=["default", "zero", "4s", "both none"] + ids=["default", "zero", "4s", "both none"], ) def test_timeout(self, opts, expected_query_chunk): cluster = TestCluster(**opts) stmts = [] class ConnectionWrapper(cluster.connection_class): - def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message, - decoder=ProtocolHandler.decode_message, result_metadata=None): + def send_msg( + self, + msg, + request_id, + cb, + encoder=ProtocolHandler.encode_message, + decoder=ProtocolHandler.decode_message, + result_metadata=None, + ): if isinstance(msg, QueryMessage): stmts.append(msg.query) - return super(ConnectionWrapper, self).send_msg(msg, request_id, cb, encoder, decoder, result_metadata) + return super(ConnectionWrapper, self).send_msg( + msg, request_id, cb, encoder, decoder, result_metadata + ) cluster.connection_class = ConnectionWrapper s = cluster.connect() @@ -1373,26 +1625,34 @@ def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message, if "SELECT now() FROM system.local WHERE key='local'" in stmt: continue if expected_query_chunk: - assert expected_query_chunk in stmt, f"query `{stmt}` does not contain `{expected_query_chunk}`" + assert expected_query_chunk in stmt, ( + f"query `{stmt}` does not contain `{expected_query_chunk}`" + ) else: - assert 'USING TIMEOUT' not in stmt, f"query `{stmt}` should not contain `USING TIMEOUT`" + assert "USING TIMEOUT" not in stmt, ( + f"query `{stmt}` should not contain `USING TIMEOUT`" + ) class KeyspaceAlterMetadata(unittest.TestCase): """ Test verifies that table metadata is preserved on keyspace alter """ + def setUp(self): self.cluster = TestCluster() self.session = self.cluster.connect() name = self._testMethodName.lower() - crt_ks = ''' - CREATE KEYSPACE %s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1} AND durable_writes = true''' % name + crt_ks = ( + """ + CREATE KEYSPACE %s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1} AND durable_writes = true""" + % name + ) self.session.execute(crt_ks) def tearDown(self): name = self._testMethodName.lower() - self.session.execute('DROP KEYSPACE %s' % name) + self.session.execute("DROP KEYSPACE %s" % name) self.cluster.shutdown() def test_keyspace_alter(self): @@ -1407,20 +1667,19 @@ def test_keyspace_alter(self): """ name = self._testMethodName.lower() - self.session.execute('CREATE TABLE %s.d (d INT PRIMARY KEY)' % name) + self.session.execute("CREATE TABLE %s.d (d INT PRIMARY KEY)" % name) original_keyspace_meta = self.cluster.metadata.keyspaces[name] assert original_keyspace_meta.durable_writes == True assert len(original_keyspace_meta.tables) == 1 - self.session.execute('ALTER KEYSPACE %s WITH durable_writes = false' % name) + self.session.execute("ALTER KEYSPACE %s WITH durable_writes = false" % name) new_keyspace_meta = self.cluster.metadata.keyspaces[name] assert original_keyspace_meta != new_keyspace_meta assert new_keyspace_meta.durable_writes == False class IndexMapTests(unittest.TestCase): - - keyspace_name = 'index_map_tests' + keyspace_name = "index_map_tests" @property def table_name(self): @@ -1438,7 +1697,9 @@ def setup_class(cls): """ CREATE KEYSPACE %s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}; - """ % cls.keyspace_name) + """ + % cls.keyspace_name + ) cls.session.set_keyspace(cls.keyspace_name) except Exception: cls.cluster.shutdown() @@ -1452,7 +1713,9 @@ def teardown_class(cls): cls.cluster.shutdown() def create_basic_table(self): - self.session.execute("CREATE TABLE %s (k int PRIMARY KEY, a int)" % self.table_name) + self.session.execute( + "CREATE TABLE %s (k int PRIMARY KEY, a int)" % self.table_name + ) def drop_basic_table(self): self.session.execute("DROP TABLE %s" % self.table_name) @@ -1462,10 +1725,10 @@ def test_index_updates(self): ks_meta = self.cluster.metadata.keyspaces[self.keyspace_name] table_meta = ks_meta.tables[self.table_name] - assert 'a_idx' not in ks_meta.indexes - assert 'b_idx' not in ks_meta.indexes - assert 'a_idx' not in table_meta.indexes - assert 'b_idx' not in table_meta.indexes + assert "a_idx" not in ks_meta.indexes + assert "b_idx" not in ks_meta.indexes + assert "a_idx" not in table_meta.indexes + assert "b_idx" not in table_meta.indexes self.session.execute("CREATE INDEX a_idx ON %s (a)" % self.table_name) self.session.execute("ALTER TABLE %s ADD b int" % self.table_name) @@ -1473,10 +1736,10 @@ def test_index_updates(self): ks_meta = self.cluster.metadata.keyspaces[self.keyspace_name] table_meta = ks_meta.tables[self.table_name] - assert isinstance(ks_meta.indexes['a_idx'], IndexMetadata) - assert isinstance(ks_meta.indexes['b_idx'], IndexMetadata) - assert isinstance(table_meta.indexes['a_idx'], IndexMetadata) - assert isinstance(table_meta.indexes['b_idx'], IndexMetadata) + assert isinstance(ks_meta.indexes["a_idx"], IndexMetadata) + assert isinstance(ks_meta.indexes["b_idx"], IndexMetadata) + assert isinstance(table_meta.indexes["a_idx"], IndexMetadata) + assert isinstance(table_meta.indexes["b_idx"], IndexMetadata) # both indexes updated when index dropped self.session.execute("DROP INDEX a_idx") @@ -1486,28 +1749,30 @@ def test_index_updates(self): ks_meta = self.cluster.metadata.keyspaces[self.keyspace_name] table_meta = ks_meta.tables[self.table_name] - assert 'a_idx' not in ks_meta.indexes - assert isinstance(ks_meta.indexes['b_idx'], IndexMetadata) - assert 'a_idx' not in table_meta.indexes - assert isinstance(table_meta.indexes['b_idx'], IndexMetadata) + assert "a_idx" not in ks_meta.indexes + assert isinstance(ks_meta.indexes["b_idx"], IndexMetadata) + assert "a_idx" not in table_meta.indexes + assert isinstance(table_meta.indexes["b_idx"], IndexMetadata) # keyspace index updated when table dropped self.drop_basic_table() ks_meta = self.cluster.metadata.keyspaces[self.keyspace_name] assert self.table_name not in ks_meta.tables - assert 'a_idx' not in ks_meta.indexes - assert 'b_idx' not in ks_meta.indexes + assert "a_idx" not in ks_meta.indexes + assert "b_idx" not in ks_meta.indexes def test_index_follows_alter(self): self.create_basic_table() - idx = self.table_name + '_idx' + idx = self.table_name + "_idx" self.session.execute("CREATE INDEX %s ON %s (a)" % (idx, self.table_name)) ks_meta = self.cluster.metadata.keyspaces[self.keyspace_name] table_meta = ks_meta.tables[self.table_name] assert isinstance(ks_meta.indexes[idx], IndexMetadata) assert isinstance(table_meta.indexes[idx], IndexMetadata) - self.session.execute('ALTER KEYSPACE %s WITH durable_writes = false' % self.keyspace_name) + self.session.execute( + "ALTER KEYSPACE %s WITH durable_writes = false" % self.keyspace_name + ) old_meta = ks_meta ks_meta = self.cluster.metadata.keyspaces[self.keyspace_name] assert ks_meta is not old_meta @@ -1516,6 +1781,7 @@ def test_index_follows_alter(self): assert isinstance(table_meta.indexes[idx], IndexMetadata) self.drop_basic_table() + @requires_java_udf class FunctionTest(unittest.TestCase): """ @@ -1528,7 +1794,9 @@ def setUp(self): """ if PROTOCOL_VERSION < 4: - raise unittest.SkipTest("Function metadata requires native protocol version 4+") + raise unittest.SkipTest( + "Function metadata requires native protocol version 4+" + ) @property def function_name(self): @@ -1540,10 +1808,17 @@ def setup_class(cls): cls.cluster = TestCluster() cls.keyspace_name = cls.__name__.lower() cls.session = cls.cluster.connect() - cls.session.execute("CREATE KEYSPACE IF NOT EXISTS %s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}" % cls.keyspace_name) + cls.session.execute( + "CREATE KEYSPACE IF NOT EXISTS %s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}" + % cls.keyspace_name + ) cls.session.set_keyspace(cls.keyspace_name) - cls.keyspace_function_meta = cls.cluster.metadata.keyspaces[cls.keyspace_name].functions - cls.keyspace_aggregate_meta = cls.cluster.metadata.keyspaces[cls.keyspace_name].aggregates + cls.keyspace_function_meta = cls.cluster.metadata.keyspaces[ + cls.keyspace_name + ].functions + cls.keyspace_aggregate_meta = cls.cluster.metadata.keyspaces[ + cls.keyspace_name + ].aggregates @classmethod def teardown_class(cls): @@ -1552,7 +1827,6 @@ def teardown_class(cls): cls.cluster.shutdown() class Verified(object): - def __init__(self, test_case, meta_class, element_meta, **function_kwargs): self.test_case = test_case self.function_kwargs = dict(function_kwargs) @@ -1572,38 +1846,47 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): tc = self.test_case - tc.session.execute("DROP %s %s.%s" % (self.meta_class.__name__, tc.keyspace_name, self.signature)) + tc.session.execute( + "DROP %s %s.%s" + % (self.meta_class.__name__, tc.keyspace_name, self.signature) + ) assert self.signature not in self.element_meta @property def signature(self): - return SignatureDescriptor.format_signature(self.function_kwargs['name'], - self.function_kwargs['argument_types']) + return SignatureDescriptor.format_signature( + self.function_kwargs["name"], self.function_kwargs["argument_types"] + ) class VerifiedFunction(Verified): def __init__(self, test_case, **kwargs): - super(FunctionTest.VerifiedFunction, self).__init__(test_case, Function, test_case.keyspace_function_meta, **kwargs) + super(FunctionTest.VerifiedFunction, self).__init__( + test_case, Function, test_case.keyspace_function_meta, **kwargs + ) class VerifiedAggregate(Verified): def __init__(self, test_case, **kwargs): - super(FunctionTest.VerifiedAggregate, self).__init__(test_case, Aggregate, test_case.keyspace_aggregate_meta, **kwargs) + super(FunctionTest.VerifiedAggregate, self).__init__( + test_case, Aggregate, test_case.keyspace_aggregate_meta, **kwargs + ) @requires_java_udf class FunctionMetadata(FunctionTest): - def make_function_kwargs(self, called_on_null=True): - return {'keyspace': self.keyspace_name, - 'name': self.function_name, - 'argument_types': ['double', 'int'], - 'argument_names': ['d', 'i'], - 'return_type': 'double', - 'language': 'java', - 'body': 'return new Double(0.0);', - 'called_on_null_input': called_on_null, - 'deterministic': False, - 'monotonic': False, - 'monotonic_on': []} + return { + "keyspace": self.keyspace_name, + "name": self.function_name, + "argument_types": ["double", "int"], + "argument_names": ["d", "i"], + "return_type": "double", + "language": "java", + "body": "return new Double(0.0);", + "called_on_null_input": called_on_null, + "deterministic": False, + "monotonic": False, + "monotonic_on": [], + } def test_functions_after_udt(self): """ @@ -1629,15 +1912,19 @@ def test_functions_after_udt(self): assert self.function_name not in self.keyspace_function_meta - udt_name = 'udtx' + udt_name = "udtx" self.session.execute("CREATE TYPE %s (x int)" % udt_name) with self.VerifiedFunction(self, **self.make_function_kwargs()): # udts must come before functions in keyspace dump - keyspace_cql = self.cluster.metadata.keyspaces[self.keyspace_name].export_as_string() + keyspace_cql = self.cluster.metadata.keyspaces[ + self.keyspace_name + ].export_as_string() type_idx = keyspace_cql.rfind("CREATE TYPE") func_idx = keyspace_cql.find("CREATE FUNCTION") - assert -1 not in (type_idx, func_idx), "TYPE or FUNCTION not found in keyspace_cql: " + keyspace_cql + assert -1 not in (type_idx, func_idx), ( + "TYPE or FUNCTION not found in keyspace_cql: " + keyspace_cql + ) assert func_idx > type_idx def test_function_same_name_diff_types(self): @@ -1656,16 +1943,19 @@ def test_function_same_name_diff_types(self): # Create a function kwargs = self.make_function_kwargs() with self.VerifiedFunction(self, **kwargs): - # another function: same name, different type sig. - assert len(kwargs['argument_types']) > 1 - assert len(kwargs['argument_names']) > 1 - kwargs['argument_types'] = kwargs['argument_types'][:1] - kwargs['argument_names'] = kwargs['argument_names'][:1] + assert len(kwargs["argument_types"]) > 1 + assert len(kwargs["argument_names"]) > 1 + kwargs["argument_types"] = kwargs["argument_types"][:1] + kwargs["argument_names"] = kwargs["argument_names"][:1] # Ensure they are surfaced separately with self.VerifiedFunction(self, **kwargs): - functions = [f for f in self.keyspace_function_meta.values() if f.name == self.function_name] + functions = [ + f + for f in self.keyspace_function_meta.values() + if f.name == self.function_name + ] assert len(functions) == 2 assert functions[0].argument_types != functions[1].argument_types @@ -1681,14 +1971,16 @@ def test_function_no_parameters(self): @test_category function """ kwargs = self.make_function_kwargs() - kwargs['argument_types'] = [] - kwargs['argument_names'] = [] - kwargs['return_type'] = 'bigint' - kwargs['body'] = 'return System.currentTimeMillis() / 1000L;' + kwargs["argument_types"] = [] + kwargs["argument_names"] = [] + kwargs["return_type"] = "bigint" + kwargs["body"] = "return System.currentTimeMillis() / 1000L;" with self.VerifiedFunction(self, **kwargs) as vf: fn_meta = self.keyspace_function_meta[vf.signature] - assertRegex(fn_meta.as_cql_query(), r'CREATE FUNCTION.*%s\(\) .*' % kwargs['name']) + assertRegex( + fn_meta.as_cql_query(), r"CREATE FUNCTION.*%s\(\) .*" % kwargs["name"] + ) def test_functions_follow_keyspace_alter(self): """ @@ -1707,7 +1999,9 @@ def test_functions_follow_keyspace_alter(self): # Create function with self.VerifiedFunction(self, **self.make_function_kwargs()): original_keyspace_meta = self.cluster.metadata.keyspaces[self.keyspace_name] - self.session.execute('ALTER KEYSPACE %s WITH durable_writes = false' % self.keyspace_name) + self.session.execute( + "ALTER KEYSPACE %s WITH durable_writes = false" % self.keyspace_name + ) # After keyspace alter ensure that we maintain function equality. try: @@ -1715,7 +2009,9 @@ def test_functions_follow_keyspace_alter(self): assert original_keyspace_meta != new_keyspace_meta assert original_keyspace_meta.functions is new_keyspace_meta.functions finally: - self.session.execute('ALTER KEYSPACE %s WITH durable_writes = true' % self.keyspace_name) + self.session.execute( + "ALTER KEYSPACE %s WITH durable_writes = true" % self.keyspace_name + ) def test_function_cql_called_on_null(self): """ @@ -1733,20 +2029,25 @@ def test_function_cql_called_on_null(self): """ kwargs = self.make_function_kwargs() - kwargs['called_on_null_input'] = True + kwargs["called_on_null_input"] = True with self.VerifiedFunction(self, **kwargs) as vf: fn_meta = self.keyspace_function_meta[vf.signature] - assertRegex(fn_meta.as_cql_query(), r'CREATE FUNCTION.*\) CALLED ON NULL INPUT RETURNS .*') + assertRegex( + fn_meta.as_cql_query(), + r"CREATE FUNCTION.*\) CALLED ON NULL INPUT RETURNS .*", + ) - kwargs['called_on_null_input'] = False + kwargs["called_on_null_input"] = False with self.VerifiedFunction(self, **kwargs) as vf: fn_meta = self.keyspace_function_meta[vf.signature] - assertRegex(fn_meta.as_cql_query(), r'CREATE FUNCTION.*\) RETURNS NULL ON NULL INPUT RETURNS .*') + assertRegex( + fn_meta.as_cql_query(), + r"CREATE FUNCTION.*\) RETURNS NULL ON NULL INPUT RETURNS .*", + ) @requires_java_udf class AggregateMetadata(FunctionTest): - @classmethod def setup_class(cls): if PROTOCOL_VERSION >= 4: @@ -1778,16 +2079,20 @@ def setup_class(cls): cls.session.execute("INSERT INTO t (k,v) VALUES (%s, %s)", (x, x)) cls.session.execute("INSERT INTO t (k) VALUES (%s)", (4,)) - def make_aggregate_kwargs(self, state_func, state_type, final_func=None, init_cond=None): - return {'keyspace': self.keyspace_name, - 'name': self.function_name + '_aggregate', - 'argument_types': ['int'], - 'state_func': state_func, - 'state_type': state_type, - 'final_func': final_func, - 'initial_condition': init_cond, - 'return_type': "does not matter for creation", - 'deterministic': False} + def make_aggregate_kwargs( + self, state_func, state_type, final_func=None, init_cond=None + ): + return { + "keyspace": self.keyspace_name, + "name": self.function_name + "_aggregate", + "argument_types": ["int"], + "state_func": state_func, + "state_type": state_type, + "final_func": final_func, + "initial_condition": init_cond, + "return_type": "does not matter for creation", + "deterministic": False, + } def test_return_type_meta(self): """ @@ -1803,8 +2108,10 @@ def test_return_type_meta(self): @test_category aggregate """ - with self.VerifiedAggregate(self, **self.make_aggregate_kwargs('sum_int', 'int', init_cond='1')) as va: - assert self.keyspace_aggregate_meta[va.signature].return_type == 'int' + with self.VerifiedAggregate( + self, **self.make_aggregate_kwargs("sum_int", "int", init_cond="1") + ) as va: + assert self.keyspace_aggregate_meta[va.signature].return_type == "int" def test_init_cond(self): """ @@ -1831,27 +2138,59 @@ def test_init_cond(self): # int32 for init_cond in (-1, 0, 1): cql_init = encoder.cql_encode_all_types(init_cond) - with self.VerifiedAggregate(self, **self.make_aggregate_kwargs('sum_int', 'int', init_cond=cql_init)) as va: - sum_res = s.execute("SELECT %s(v) AS sum FROM t" % va.function_kwargs['name']).one().sum + with self.VerifiedAggregate( + self, **self.make_aggregate_kwargs("sum_int", "int", init_cond=cql_init) + ) as va: + sum_res = ( + s.execute("SELECT %s(v) AS sum FROM t" % va.function_kwargs["name"]) + .one() + .sum + ) assert sum_res == int(init_cond) + sum(expected_values) # list - for init_cond in ([], ['1', '2']): + for init_cond in ([], ["1", "2"]): cql_init = encoder.cql_encode_all_types(init_cond) - with self.VerifiedAggregate(self, **self.make_aggregate_kwargs('extend_list', 'list', init_cond=cql_init)) as va: - list_res = s.execute("SELECT %s(v) AS list_res FROM t" % va.function_kwargs['name']).one().list_res - assertListEqual(list_res[:len(init_cond)], init_cond) - assert set(i for i in list_res[len(init_cond):]) == set(str(i) for i in expected_values) + with self.VerifiedAggregate( + self, + **self.make_aggregate_kwargs( + "extend_list", "list", init_cond=cql_init + ), + ) as va: + list_res = ( + s.execute( + "SELECT %s(v) AS list_res FROM t" % va.function_kwargs["name"] + ) + .one() + .list_res + ) + assertListEqual(list_res[: len(init_cond)], init_cond) + assert set(i for i in list_res[len(init_cond) :]) == set( + str(i) for i in expected_values + ) # map expected_map_values = dict((i, i) for i in expected_values) expected_key_set = set(expected_values) for init_cond in ({}, {1: 2, 3: 4}, {5: 5}): cql_init = encoder.cql_encode_all_types(init_cond) - with self.VerifiedAggregate(self, **self.make_aggregate_kwargs('update_map', 'map', init_cond=cql_init)) as va: - map_res = s.execute("SELECT %s(v) AS map_res FROM t" % va.function_kwargs['name']).one().map_res + with self.VerifiedAggregate( + self, + **self.make_aggregate_kwargs( + "update_map", "map", init_cond=cql_init + ), + ) as va: + map_res = ( + s.execute( + "SELECT %s(v) AS map_res FROM t" % va.function_kwargs["name"] + ) + .one() + .map_res + ) assert expected_map_values.items() <= map_res.items() - init_not_updated = dict((k, init_cond[k]) for k in set(init_cond) - expected_key_set) + init_not_updated = dict( + (k, init_cond[k]) for k in set(init_cond) - expected_key_set + ) assert init_not_updated.items() <= map_res.items() c.shutdown() @@ -1870,11 +2209,17 @@ def test_aggregates_after_functions(self): """ # functions must come before functions in keyspace dump - with self.VerifiedAggregate(self, **self.make_aggregate_kwargs('extend_list', 'list')): - keyspace_cql = self.cluster.metadata.keyspaces[self.keyspace_name].export_as_string() + with self.VerifiedAggregate( + self, **self.make_aggregate_kwargs("extend_list", "list") + ): + keyspace_cql = self.cluster.metadata.keyspaces[ + self.keyspace_name + ].export_as_string() func_idx = keyspace_cql.find("CREATE FUNCTION") aggregate_idx = keyspace_cql.rfind("CREATE AGGREGATE") - assert -1 not in (aggregate_idx, func_idx), "AGGREGATE or FUNCTION not found in keyspace_cql: " + keyspace_cql + assert -1 not in (aggregate_idx, func_idx), ( + "AGGREGATE or FUNCTION not found in keyspace_cql: " + keyspace_cql + ) assert aggregate_idx > func_idx def test_same_name_diff_types(self): @@ -1890,12 +2235,16 @@ def test_same_name_diff_types(self): @test_category function """ - kwargs = self.make_aggregate_kwargs('sum_int', 'int', init_cond='0') + kwargs = self.make_aggregate_kwargs("sum_int", "int", init_cond="0") with self.VerifiedAggregate(self, **kwargs): - kwargs['state_func'] = 'sum_int_two' - kwargs['argument_types'] = ['int', 'int'] + kwargs["state_func"] = "sum_int_two" + kwargs["argument_types"] = ["int", "int"] with self.VerifiedAggregate(self, **kwargs): - aggregates = [a for a in self.keyspace_aggregate_meta.values() if a.name == kwargs['name']] + aggregates = [ + a + for a in self.keyspace_aggregate_meta.values() + if a.name == kwargs["name"] + ] assert len(aggregates) == 2 assert aggregates[0].argument_types != aggregates[1].argument_types @@ -1913,15 +2262,21 @@ def test_aggregates_follow_keyspace_alter(self): @test_category function """ - with self.VerifiedAggregate(self, **self.make_aggregate_kwargs('sum_int', 'int', init_cond='0')): + with self.VerifiedAggregate( + self, **self.make_aggregate_kwargs("sum_int", "int", init_cond="0") + ): original_keyspace_meta = self.cluster.metadata.keyspaces[self.keyspace_name] - self.session.execute('ALTER KEYSPACE %s WITH durable_writes = false' % self.keyspace_name) + self.session.execute( + "ALTER KEYSPACE %s WITH durable_writes = false" % self.keyspace_name + ) try: new_keyspace_meta = self.cluster.metadata.keyspaces[self.keyspace_name] assert original_keyspace_meta != new_keyspace_meta assert original_keyspace_meta.aggregates is new_keyspace_meta.aggregates finally: - self.session.execute('ALTER KEYSPACE %s WITH durable_writes = true' % self.keyspace_name) + self.session.execute( + "ALTER KEYSPACE %s WITH durable_writes = true" % self.keyspace_name + ) def test_cql_optional_params(self): """ @@ -1937,53 +2292,57 @@ def test_cql_optional_params(self): @test_category function """ - kwargs = self.make_aggregate_kwargs('extend_list', 'list') + kwargs = self.make_aggregate_kwargs("extend_list", "list") encoder = Encoder() # no initial condition, final func - assert kwargs['initial_condition'] is None - assert kwargs['final_func'] is None + assert kwargs["initial_condition"] is None + assert kwargs["final_func"] is None with self.VerifiedAggregate(self, **kwargs) as va: meta = self.keyspace_aggregate_meta[va.signature] assert meta.initial_condition is None assert meta.final_func is None cql = meta.as_cql_query() - assert cql.find('INITCOND') == -1 - assert cql.find('FINALFUNC') == -1 + assert cql.find("INITCOND") == -1 + assert cql.find("FINALFUNC") == -1 # initial condition, no final func - kwargs['initial_condition'] = encoder.cql_encode_all_types(['init', 'cond']) + kwargs["initial_condition"] = encoder.cql_encode_all_types(["init", "cond"]) with self.VerifiedAggregate(self, **kwargs) as va: meta = self.keyspace_aggregate_meta[va.signature] - assert meta.initial_condition == kwargs['initial_condition'] + assert meta.initial_condition == kwargs["initial_condition"] assert meta.final_func is None cql = meta.as_cql_query() - search_string = "INITCOND %s" % kwargs['initial_condition'] - assert cql.find(search_string) > 0, '"%s" search string not found in cql:\n%s' % (search_string, cql) - assert cql.find('FINALFUNC') == -1 + search_string = "INITCOND %s" % kwargs["initial_condition"] + assert cql.find(search_string) > 0, ( + '"%s" search string not found in cql:\n%s' % (search_string, cql) + ) + assert cql.find("FINALFUNC") == -1 # no initial condition, final func - kwargs['initial_condition'] = None - kwargs['final_func'] = 'List_As_String' + kwargs["initial_condition"] = None + kwargs["final_func"] = "List_As_String" with self.VerifiedAggregate(self, **kwargs) as va: meta = self.keyspace_aggregate_meta[va.signature] assert meta.initial_condition is None - assert meta.final_func == kwargs['final_func'] + assert meta.final_func == kwargs["final_func"] cql = meta.as_cql_query() - assert cql.find('INITCOND') == -1 - search_string = 'FINALFUNC "%s"' % kwargs['final_func'] - assert cql.find(search_string) > 0, '"%s" search string not found in cql:\n%s' % (search_string, cql) + assert cql.find("INITCOND") == -1 + search_string = 'FINALFUNC "%s"' % kwargs["final_func"] + assert cql.find(search_string) > 0, ( + '"%s" search string not found in cql:\n%s' % (search_string, cql) + ) # both - kwargs['initial_condition'] = encoder.cql_encode_all_types(['init', 'cond']) - kwargs['final_func'] = 'List_As_String' + kwargs["initial_condition"] = encoder.cql_encode_all_types(["init", "cond"]) + kwargs["final_func"] = "List_As_String" with self.VerifiedAggregate(self, **kwargs) as va: meta = self.keyspace_aggregate_meta[va.signature] - assert meta.initial_condition == kwargs['initial_condition'] - assert meta.final_func == kwargs['final_func'] + assert meta.initial_condition == kwargs["initial_condition"] + assert meta.final_func == kwargs["final_func"] cql = meta.as_cql_query() - init_cond_idx = cql.find("INITCOND %s" % kwargs['initial_condition']) - final_func_idx = cql.find('FINALFUNC "%s"' % kwargs['final_func']) + init_cond_idx = cql.find("INITCOND %s" % kwargs["initial_condition"]) + final_func_idx = cql.find('FINALFUNC "%s"' % kwargs["final_func"]) assert -1 not in (init_cond_idx, final_func_idx) assert init_cond_idx > final_func_idx @@ -2007,7 +2366,10 @@ def setup_class(cls): cls.cluster = TestCluster() cls.keyspace_name = cls.__name__.lower() cls.session = cls.cluster.connect() - cls.session.execute("CREATE KEYSPACE %s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}" % cls.keyspace_name) + cls.session.execute( + "CREATE KEYSPACE %s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}" + % cls.keyspace_name + ) cls.session.set_keyspace(cls.keyspace_name) connection = cls.cluster.control_connection._connection @@ -2025,34 +2387,56 @@ def teardown_class(cls): drop_keyspace_shutdown_cluster(cls.keyspace_name, cls.session, cls.cluster) def test_bad_keyspace(self): - with patch.object(self.parser_class, '_build_keyspace_metadata_internal', side_effect=self.BadMetaException): + with patch.object( + self.parser_class, + "_build_keyspace_metadata_internal", + side_effect=self.BadMetaException, + ): self.cluster.refresh_keyspace_metadata(self.keyspace_name) m = self.cluster.metadata.keyspaces[self.keyspace_name] assert m._exc_info[0] is self.BadMetaException assert "/*\nWarning:" in m.export_as_string() def test_bad_table(self): - self.session.execute('CREATE TABLE %s (k int PRIMARY KEY, v int)' % self.function_name) - with patch.object(self.parser_class, '_build_column_metadata', side_effect=self.BadMetaException): + self.session.execute( + "CREATE TABLE %s (k int PRIMARY KEY, v int)" % self.function_name + ) + with patch.object( + self.parser_class, + "_build_column_metadata", + side_effect=self.BadMetaException, + ): self.cluster.refresh_table_metadata(self.keyspace_name, self.function_name) - m = self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_name] + m = self.cluster.metadata.keyspaces[self.keyspace_name].tables[ + self.function_name + ] assert m._exc_info[0] is self.BadMetaException assert "/*\nWarning:" in m.export_as_string() def test_bad_index(self): - self.session.execute('CREATE TABLE %s (k int PRIMARY KEY, v int)' % self.function_name) - self.session.execute('CREATE INDEX ON %s(v)' % self.function_name) - with patch.object(self.parser_class, '_build_index_metadata', side_effect=self.BadMetaException): + self.session.execute( + "CREATE TABLE %s (k int PRIMARY KEY, v int)" % self.function_name + ) + self.session.execute("CREATE INDEX ON %s(v)" % self.function_name) + with patch.object( + self.parser_class, + "_build_index_metadata", + side_effect=self.BadMetaException, + ): self.cluster.refresh_table_metadata(self.keyspace_name, self.function_name) - m = self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_name] + m = self.cluster.metadata.keyspaces[self.keyspace_name].tables[ + self.function_name + ] assert m._exc_info[0] is self.BadMetaException assert "/*\nWarning:" in m.export_as_string() @greaterthancass20 def test_bad_user_type(self): - self.session.execute('CREATE TYPE %s (i int, d double)' % self.function_name) - with patch.object(self.parser_class, '_build_user_type', side_effect=self.BadMetaException): - self.cluster.refresh_schema_metadata() # presently do not capture these errors on udt direct refresh -- make sure it's contained during full refresh + self.session.execute("CREATE TYPE %s (i int, d double)" % self.function_name) + with patch.object( + self.parser_class, "_build_user_type", side_effect=self.BadMetaException + ): + self.cluster.refresh_schema_metadata() # presently do not capture these errors on udt direct refresh -- make sure it's contained during full refresh m = self.cluster.metadata.keyspaces[self.keyspace_name] assert m._exc_info[0] is self.BadMetaException assert "/*\nWarning:" in m.export_as_string() @@ -2060,18 +2444,23 @@ def test_bad_user_type(self): @greaterthancass21 @requires_java_udf def test_bad_user_function(self): - self.session.execute("""CREATE FUNCTION IF NOT EXISTS %s (key int, val int) + self.session.execute( + """CREATE FUNCTION IF NOT EXISTS %s (key int, val int) RETURNS NULL ON NULL INPUT RETURNS int - LANGUAGE java AS 'return key + val;';""" % self.function_name) - - #We need to patch as well the reconnect function because after patching the _build_function - #there will an Error refreshing schema which will trigger a reconnection. If this happened - #in a timely manner in the call self.cluster.refresh_schema_metadata() it would return an exception - #due to that a connection would be closed - with patch.object(self.cluster.control_connection, 'reconnect'): - with patch.object(self.parser_class, '_build_function', side_effect=self.BadMetaException): - self.cluster.refresh_schema_metadata() # presently do not capture these errors on udt direct refresh -- make sure it's contained during full refresh + LANGUAGE java AS 'return key + val;';""" + % self.function_name + ) + + # We need to patch as well the reconnect function because after patching the _build_function + # there will an Error refreshing schema which will trigger a reconnection. If this happened + # in a timely manner in the call self.cluster.refresh_schema_metadata() it would return an exception + # due to that a connection would be closed + with patch.object(self.cluster.control_connection, "reconnect"): + with patch.object( + self.parser_class, "_build_function", side_effect=self.BadMetaException + ): + self.cluster.refresh_schema_metadata() # presently do not capture these errors on udt direct refresh -- make sure it's contained during full refresh m = self.cluster.metadata.keyspaces[self.keyspace_name] assert m._exc_info[0] is self.BadMetaException assert "/*\nWarning:" in m.export_as_string() @@ -2083,21 +2472,25 @@ def test_bad_user_aggregate(self): RETURNS NULL ON NULL INPUT RETURNS int LANGUAGE java AS 'return key + val;';""") - self.session.execute("""CREATE AGGREGATE %s(int) + self.session.execute( + """CREATE AGGREGATE %s(int) SFUNC sum_int STYPE int - INITCOND 0""" % self.function_name) - #We have the same issue here as in test_bad_user_function - with patch.object(self.cluster.control_connection, 'reconnect'): - with patch.object(self.parser_class, '_build_aggregate', side_effect=self.BadMetaException): - self.cluster.refresh_schema_metadata() # presently do not capture these errors on udt direct refresh -- make sure it's contained during full refresh + INITCOND 0""" + % self.function_name + ) + # We have the same issue here as in test_bad_user_function + with patch.object(self.cluster.control_connection, "reconnect"): + with patch.object( + self.parser_class, "_build_aggregate", side_effect=self.BadMetaException + ): + self.cluster.refresh_schema_metadata() # presently do not capture these errors on udt direct refresh -- make sure it's contained during full refresh m = self.cluster.metadata.keyspaces[self.keyspace_name] assert m._exc_info[0] is self.BadMetaException assert "/*\nWarning:" in m.export_as_string() class DynamicCompositeTypeTest(BasicSharedKeyspaceUnitTestCase): - @requires_composite_type def test_dct_alias(self): """ @@ -2111,18 +2504,24 @@ def test_dct_alias(self): @test_category metadata """ - self.session.execute("CREATE TABLE {0}.{1} (" - "k int PRIMARY KEY," - "c1 'DynamicCompositeType(s => UTF8Type, i => Int32Type)'," - "c2 Text)".format(self.ks_name, self.function_table_name)) - dct_table = self.cluster.metadata.keyspaces.get(self.ks_name).tables.get(self.function_table_name) + self.session.execute( + "CREATE TABLE {0}.{1} (" + "k int PRIMARY KEY," + "c1 'DynamicCompositeType(s => UTF8Type, i => Int32Type)'," + "c2 Text)".format(self.ks_name, self.function_table_name) + ) + dct_table = self.cluster.metadata.keyspaces.get(self.ks_name).tables.get( + self.function_table_name + ) # Format can very slightly between versions, strip out whitespace for consistency sake table_text = dct_table.as_cql_query().replace(" ", "") dynamic_type_text = "c1'org.apache.cassandra.db.marshal.DynamicCompositeType(" assert "c1'org.apache.cassandra.db.marshal.DynamicCompositeType(" in table_text # Types within in the composite can come out in random order, so grab the type definition and find each one - type_definition_start = table_text.index("(", table_text.find(dynamic_type_text)) + type_definition_start = table_text.index( + "(", table_text.find(dynamic_type_text) + ) type_definition_end = table_text.index(")") type_definition_text = table_text[type_definition_start:type_definition_end] assert "s=>org.apache.cassandra.db.marshal.UTF8Type" in type_definition_text @@ -2131,19 +2530,27 @@ def test_dct_alias(self): @greaterthanorequalcass30 class MaterializedViewMetadataTestSimple(BasicSharedKeyspaceUnitTestCase): - def setUp(self): - self.session.execute("CREATE TABLE {0}.{1} (pk int PRIMARY KEY, c int)".format(self.keyspace_name, self.function_table_name)) + self.session.execute( + "CREATE TABLE {0}.{1} (pk int PRIMARY KEY, c int)".format( + self.keyspace_name, self.function_table_name + ) + ) self.session.execute( "CREATE MATERIALIZED VIEW {0}.mv1 AS SELECT pk, c FROM {0}.{1} " "WHERE pk IS NOT NULL AND c IS NOT NULL PRIMARY KEY (pk, c) " "WITH compaction = {{ 'class' : 'SizeTieredCompactionStrategy' }}".format( - self.keyspace_name, self.function_table_name) + self.keyspace_name, self.function_table_name + ) ) def tearDown(self): - self.session.execute("DROP MATERIALIZED VIEW {0}.mv1".format(self.keyspace_name)) - self.session.execute("DROP TABLE {0}.{1}".format(self.keyspace_name, self.function_table_name)) + self.session.execute( + "DROP MATERIALIZED VIEW {0}.mv1".format(self.keyspace_name) + ) + self.session.execute( + "DROP TABLE {0}.{1}".format(self.keyspace_name, self.function_table_name) + ) def test_materialized_view_metadata_creation(self): """ @@ -2162,10 +2569,27 @@ def test_materialized_view_metadata_creation(self): """ assert "mv1" in self.cluster.metadata.keyspaces[self.keyspace_name].views - assert "mv1" in self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views + assert ( + "mv1" + in self.cluster.metadata.keyspaces[self.keyspace_name] + .tables[self.function_table_name] + .views + ) - assert self.keyspace_name == self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views["mv1"].keyspace_name - assert self.function_table_name == self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views["mv1"].base_table_name + assert ( + self.keyspace_name + == self.cluster.metadata.keyspaces[self.keyspace_name] + .tables[self.function_table_name] + .views["mv1"] + .keyspace_name + ) + assert ( + self.function_table_name + == self.cluster.metadata.keyspaces[self.keyspace_name] + .tables[self.function_table_name] + .views["mv1"] + .base_table_name + ) def test_materialized_view_metadata_alter(self): """ @@ -2182,10 +2606,26 @@ def test_materialized_view_metadata_alter(self): @test_category metadata """ - assert "SizeTieredCompactionStrategy" in self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views["mv1"].options["compaction"]["class"] + assert ( + "SizeTieredCompactionStrategy" + in self.cluster.metadata.keyspaces[self.keyspace_name] + .tables[self.function_table_name] + .views["mv1"] + .options["compaction"]["class"] + ) - self.session.execute("ALTER MATERIALIZED VIEW {0}.mv1 WITH compaction = {{ 'class' : 'LeveledCompactionStrategy' }}".format(self.keyspace_name)) - assert "LeveledCompactionStrategy" in self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views["mv1"].options["compaction"]["class"] + self.session.execute( + "ALTER MATERIALIZED VIEW {0}.mv1 WITH compaction = {{ 'class' : 'LeveledCompactionStrategy' }}".format( + self.keyspace_name + ) + ) + assert ( + "LeveledCompactionStrategy" + in self.cluster.metadata.keyspaces[self.keyspace_name] + .tables[self.function_table_name] + .views["mv1"] + .options["compaction"]["class"] + ) def test_materialized_view_metadata_drop(self): """ @@ -2203,17 +2643,30 @@ def test_materialized_view_metadata_drop(self): @test_category metadata """ - self.session.execute("DROP MATERIALIZED VIEW {0}.mv1".format(self.keyspace_name)) + self.session.execute( + "DROP MATERIALIZED VIEW {0}.mv1".format(self.keyspace_name) + ) - assert "mv1" not in self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views + assert ( + "mv1" + not in self.cluster.metadata.keyspaces[self.keyspace_name] + .tables[self.function_table_name] + .views + ) assert "mv1" not in self.cluster.metadata.keyspaces[self.keyspace_name].views - assertDictEqual({}, self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views) + assertDictEqual( + {}, + self.cluster.metadata.keyspaces[self.keyspace_name] + .tables[self.function_table_name] + .views, + ) assertDictEqual({}, self.cluster.metadata.keyspaces[self.keyspace_name].views) self.session.execute( "CREATE MATERIALIZED VIEW {0}.mv1 AS SELECT pk, c FROM {0}.{1} " "WHERE pk IS NOT NULL AND c IS NOT NULL PRIMARY KEY (pk, c)".format( - self.keyspace_name, self.function_table_name) + self.keyspace_name, self.function_table_name + ) ) @@ -2249,36 +2702,40 @@ def test_create_view_metadata(self): SELECT game, year, month, score, user, day FROM {0}.scores WHERE game IS NOT NULL AND year IS NOT NULL AND month IS NOT NULL AND score IS NOT NULL AND user IS NOT NULL AND day IS NOT NULL PRIMARY KEY ((game, year, month), score, user, day) - WITH CLUSTERING ORDER BY (score DESC, user ASC, day ASC)""".format(self.keyspace_name) + WITH CLUSTERING ORDER BY (score DESC, user ASC, day ASC)""".format( + self.keyspace_name + ) self.session.execute(create_mv) - score_table = self.cluster.metadata.keyspaces[self.keyspace_name].tables['scores'] - mv = self.cluster.metadata.keyspaces[self.keyspace_name].views['monthlyhigh'] + score_table = self.cluster.metadata.keyspaces[self.keyspace_name].tables[ + "scores" + ] + mv = self.cluster.metadata.keyspaces[self.keyspace_name].views["monthlyhigh"] assert score_table.views["monthlyhigh"] is not None assert len(score_table.views) is not None, 1 # Make sure user is a partition key, and not null assert len(score_table.partition_key) == 1 - assert score_table.columns['user'] is not None - assert score_table.columns['user'], score_table.partition_key[0] + assert score_table.columns["user"] is not None + assert score_table.columns["user"], score_table.partition_key[0] # Validate clustering keys assert len(score_table.clustering_key) == 4 - assert score_table.columns['game'] is not None - assert score_table.columns['game'], score_table.clustering_key[0] + assert score_table.columns["game"] is not None + assert score_table.columns["game"], score_table.clustering_key[0] - assert score_table.columns['year'] is not None - assert score_table.columns['year'], score_table.clustering_key[1] + assert score_table.columns["year"] is not None + assert score_table.columns["year"], score_table.clustering_key[1] - assert score_table.columns['month'] is not None - assert score_table.columns['month'], score_table.clustering_key[2] + assert score_table.columns["month"] is not None + assert score_table.columns["month"], score_table.clustering_key[2] - assert score_table.columns['day'] is not None - assert score_table.columns['day'], score_table.clustering_key[3] + assert score_table.columns["day"] is not None + assert score_table.columns["day"], score_table.clustering_key[3] - assert score_table.columns['score'] is not None + assert score_table.columns["score"] is not None # Validate basic mv information assert mv.keyspace_name == self.keyspace_name @@ -2292,17 +2749,17 @@ def test_create_view_metadata(self): game_column = mv_columns[0] assert game_column is not None - assert game_column.name == 'game' + assert game_column.name == "game" assert game_column == mv.partition_key[0] year_column = mv_columns[1] assert year_column is not None - assert year_column.name == 'year' + assert year_column.name == "year" assert year_column == mv.partition_key[1] month_column = mv_columns[2] assert month_column is not None - assert month_column.name == 'month' + assert month_column.name == "month" assert month_column == mv.partition_key[2] def compare_columns(a, b, name): @@ -2314,13 +2771,13 @@ def compare_columns(a, b, name): assert a.is_reversed == b.is_reversed score_column = mv_columns[3] - compare_columns(score_column, mv.clustering_key[0], 'score') + compare_columns(score_column, mv.clustering_key[0], "score") user_column = mv_columns[4] - compare_columns(user_column, mv.clustering_key[1], 'user') + compare_columns(user_column, mv.clustering_key[1], "user") day_column = mv_columns[5] - compare_columns(day_column, mv.clustering_key[2], 'day') + compare_columns(day_column, mv.clustering_key[2], "day") def test_base_table_column_addition_mv(self): """ @@ -2351,44 +2808,60 @@ def test_base_table_column_addition_mv(self): SELECT game, year, month, score, user, day FROM {0}.scores WHERE game IS NOT NULL AND year IS NOT NULL AND month IS NOT NULL AND score IS NOT NULL AND user IS NOT NULL AND day IS NOT NULL PRIMARY KEY ((game, year, month), score, user, day) - WITH CLUSTERING ORDER BY (score DESC, user ASC, day ASC)""".format(self.keyspace_name) + WITH CLUSTERING ORDER BY (score DESC, user ASC, day ASC)""".format( + self.keyspace_name + ) create_mv_alltime = """CREATE MATERIALIZED VIEW {0}.alltimehigh AS SELECT * FROM {0}.scores WHERE game IS NOT NULL AND score IS NOT NULL AND user IS NOT NULL AND year IS NOT NULL AND month IS NOT NULL AND day IS NOT NULL PRIMARY KEY (game, score, user, year, month, day) - WITH CLUSTERING ORDER BY (score DESC, user ASC, year ASC, month ASC, day ASC)""".format(self.keyspace_name) + WITH CLUSTERING ORDER BY (score DESC, user ASC, year ASC, month ASC, day ASC)""".format( + self.keyspace_name + ) self.session.execute(create_mv) self.session.execute(create_mv_alltime) - score_table = self.cluster.metadata.keyspaces[self.keyspace_name].tables['scores'] + score_table = self.cluster.metadata.keyspaces[self.keyspace_name].tables[ + "scores" + ] assert score_table.views["monthlyhigh"] is not None assert score_table.views["alltimehigh"] is not None assert len(self.cluster.metadata.keyspaces[self.keyspace_name].views) == 2 - insert_fouls = """ALTER TABLE {0}.scores ADD fouls INT""".format((self.keyspace_name)) + insert_fouls = """ALTER TABLE {0}.scores ADD fouls INT""".format( + (self.keyspace_name) + ) self.session.execute(insert_fouls) assert len(self.cluster.metadata.keyspaces[self.keyspace_name].views) == 2 - score_table = self.cluster.metadata.keyspaces[self.keyspace_name].tables['scores'] + score_table = self.cluster.metadata.keyspaces[self.keyspace_name].tables[ + "scores" + ] assert "fouls" in score_table.columns # This is a workaround for mv notifications being separate from base table schema responses. # This maybe fixed with future protocol changes for i in range(10): - mv_alltime = self.cluster.metadata.keyspaces[self.keyspace_name].views["alltimehigh"] - if("fouls" in mv_alltime.columns): + mv_alltime = self.cluster.metadata.keyspaces[self.keyspace_name].views[ + "alltimehigh" + ] + if "fouls" in mv_alltime.columns: break - time.sleep(.2) + time.sleep(0.2) assert "fouls" in mv_alltime.columns - mv_alltime_fouls_comumn = self.cluster.metadata.keyspaces[self.keyspace_name].views["alltimehigh"].columns['fouls'] - assert mv_alltime_fouls_comumn.cql_type == 'int' + mv_alltime_fouls_comumn = ( + self.cluster.metadata.keyspaces[self.keyspace_name] + .views["alltimehigh"] + .columns["fouls"] + ) + assert mv_alltime_fouls_comumn.cql_type == "int" @lessthancass30 def test_base_table_type_alter_mv(self): @@ -2423,25 +2896,37 @@ def test_base_table_type_alter_mv(self): SELECT game, year, month, score, user, day FROM {0}.scores WHERE game IS NOT NULL AND year IS NOT NULL AND month IS NOT NULL AND score IS NOT NULL AND user IS NOT NULL AND day IS NOT NULL PRIMARY KEY ((game, year, month), score, user, day) - WITH CLUSTERING ORDER BY (score DESC, user ASC, day ASC)""".format(self.keyspace_name) + WITH CLUSTERING ORDER BY (score DESC, user ASC, day ASC)""".format( + self.keyspace_name + ) self.session.execute(create_mv) assert len(self.cluster.metadata.keyspaces[self.keyspace_name].views) == 1 - alter_scores = """ALTER TABLE {0}.scores ALTER score TYPE blob""".format((self.keyspace_name)) + alter_scores = """ALTER TABLE {0}.scores ALTER score TYPE blob""".format( + (self.keyspace_name) + ) self.session.execute(alter_scores) assert len(self.cluster.metadata.keyspaces[self.keyspace_name].views) == 1 - score_column = self.cluster.metadata.keyspaces[self.keyspace_name].tables['scores'].columns['score'] - assert score_column.cql_type == 'blob' + score_column = ( + self.cluster.metadata.keyspaces[self.keyspace_name] + .tables["scores"] + .columns["score"] + ) + assert score_column.cql_type == "blob" # until CASSANDRA-9920+CASSANDRA-10500 MV updates are only available later with an async event for i in range(10): - score_mv_column = self.cluster.metadata.keyspaces[self.keyspace_name].views["monthlyhigh"].columns['score'] + score_mv_column = ( + self.cluster.metadata.keyspaces[self.keyspace_name] + .views["monthlyhigh"] + .columns["score"] + ) if "blob" == score_mv_column.cql_type: break - time.sleep(.2) + time.sleep(0.2) - assert score_mv_column.cql_type == 'blob' + assert score_mv_column.cql_type == "blob" def test_metadata_with_quoted_identifiers(self): """ @@ -2462,7 +2947,9 @@ def test_metadata_with_quoted_identifiers(self): "theKey" int, "the;Clustering" int, "the Value" int, - PRIMARY KEY ("theKey", "the;Clustering"))""".format(self.keyspace_name) + PRIMARY KEY ("theKey", "the;Clustering"))""".format( + self.keyspace_name + ) self.session.execute(create_table) @@ -2470,28 +2957,30 @@ def test_metadata_with_quoted_identifiers(self): SELECT "theKey", "the;Clustering", "the Value" FROM {0}.t1 WHERE "theKey" IS NOT NULL AND "the;Clustering" IS NOT NULL AND "the Value" IS NOT NULL - PRIMARY KEY ("theKey", "the;Clustering")""".format(self.keyspace_name) + PRIMARY KEY ("theKey", "the;Clustering")""".format( + self.keyspace_name + ) self.session.execute(create_mv) - t1_table = self.cluster.metadata.keyspaces[self.keyspace_name].tables['t1'] - mv = self.cluster.metadata.keyspaces[self.keyspace_name].views['mv1'] + t1_table = self.cluster.metadata.keyspaces[self.keyspace_name].tables["t1"] + mv = self.cluster.metadata.keyspaces[self.keyspace_name].views["mv1"] assert t1_table.views["mv1"] is not None assert len(t1_table.views) is not None, 1 # Validate partition key, and not null assert len(t1_table.partition_key) == 1 - assert t1_table.columns['theKey'] is not None - assert t1_table.columns['theKey'], t1_table.partition_key[0] + assert t1_table.columns["theKey"] is not None + assert t1_table.columns["theKey"], t1_table.partition_key[0] # Validate clustering key column assert len(t1_table.clustering_key) == 1 - assert t1_table.columns['the;Clustering'] is not None - assert t1_table.columns['the;Clustering'], t1_table.clustering_key[0] + assert t1_table.columns["the;Clustering"] is not None + assert t1_table.columns["the;Clustering"], t1_table.clustering_key[0] # Validate regular column - assert t1_table.columns['the Value'] is not None + assert t1_table.columns["the Value"] is not None # Validate basic mv information assert mv.keyspace_name == self.keyspace_name @@ -2505,12 +2994,12 @@ def test_metadata_with_quoted_identifiers(self): theKey_column = mv_columns[0] assert theKey_column is not None - assert theKey_column.name == 'theKey' + assert theKey_column.name == "theKey" assert theKey_column == mv.partition_key[0] cluster_column = mv_columns[1] assert cluster_column is not None - assert cluster_column.name == 'the;Clustering' + assert cluster_column.name == "the;Clustering" assert cluster_column.name == mv.clustering_key[0].name assert cluster_column.table == mv.clustering_key[0].table assert cluster_column.is_static == mv.clustering_key[0].is_static @@ -2518,7 +3007,7 @@ def test_metadata_with_quoted_identifiers(self): value_column = mv_columns[2] assert value_column is not None - assert value_column.name == 'the Value' + assert value_column.name == "the Value" class GroupPerHost(BasicSharedKeyspaceUnitTestCase): @@ -2527,13 +3016,13 @@ def setUpClass(cls): cls.common_setup(rf=1, create_class_table=True) cls.table_two_pk = "table_with_two_pk" cls.session.execute( - ''' + """ CREATE TABLE {0}.{1} ( k_one int, k_two int, v int, PRIMARY KEY ((k_one, k_two)) - )'''.format(cls.ks_name, cls.table_two_pk) + )""".format(cls.ks_name, cls.table_two_pk) ) def test_group_keys_by_host(self): @@ -2548,7 +3037,9 @@ def test_group_keys_by_host(self): @test_category metadata """ stmt = """SELECT * FROM {}.{} - WHERE k_one = ? AND k_two = ? """.format(self.ks_name, self.table_two_pk) + WHERE k_one = ? AND k_two = ? """.format( + self.ks_name, self.table_two_pk + ) keys = ((1, 2), (2, 2), (2, 3), (3, 4)) self._assert_group_keys_by_host(keys, self.table_two_pk, stmt) @@ -2558,7 +3049,9 @@ def test_group_keys_by_host(self): self._assert_group_keys_by_host(keys, self.ks_name, stmt) def _assert_group_keys_by_host(self, keys, table_name, stmt): - keys_per_host = group_keys_by_replica(self.session, self.ks_name, table_name, keys) + keys_per_host = group_keys_by_replica( + self.session, self.ks_name, table_name, keys + ) assert NO_VALID_REPLICA not in keys_per_host prepared_stmt = self.session.prepare(stmt) From 879a840cb2d7d130b9dfb3f088233a8953e994f7 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Wed, 4 Mar 2026 19:13:48 +0200 Subject: [PATCH 14/18] remove: dead unichr() code blocks guarded by Python 2 version check Two code blocks in test_validation.py were guarded by 'if sys.version_info < (3, 1):', making them unreachable on Python 3. The blocks used unichr() (a Python 2 builtin that does not exist in Python 3) and u'' string prefixes for unicode validation tests. In Python 3, chr() already returns a unicode character and all strings are unicode, so the adjacent chr(233) tests already cover the same functionality. Remove the dead blocks entirely. --- .../cqlengine/columns/test_validation.py | 350 ++++++++++-------- 1 file changed, 188 insertions(+), 162 deletions(-) diff --git a/tests/integration/cqlengine/columns/test_validation.py b/tests/integration/cqlengine/columns/test_validation.py index ebffc0666c..cf05c3d6f5 100644 --- a/tests/integration/cqlengine/columns/test_validation.py +++ b/tests/integration/cqlengine/columns/test_validation.py @@ -21,24 +21,46 @@ from packaging.version import Version from cassandra import InvalidRequest -from cassandra.cqlengine.columns import (TimeUUID, Ascii, Text, Integer, BigInt, - VarInt, DateTime, Date, UUID, Boolean, - Decimal, Inet, Time, UserDefinedType, - Map, List, Set, Tuple, Double, Duration) +from cassandra.cqlengine.columns import ( + TimeUUID, + Ascii, + Text, + Integer, + BigInt, + VarInt, + DateTime, + Date, + UUID, + Boolean, + Decimal, + Inet, + Time, + UserDefinedType, + Map, + List, + Set, + Tuple, + Double, + Duration, +) from cassandra.cqlengine.connection import execute from cassandra.cqlengine.management import sync_table, drop_table from cassandra.cqlengine.models import Model, ValidationError from cassandra.cqlengine.usertype import UserType from cassandra import util -from tests.integration import PROTOCOL_VERSION, CASSANDRA_VERSION, greaterthanorequalcass30, greaterthanorequalcass3_11 +from tests.integration import ( + PROTOCOL_VERSION, + CASSANDRA_VERSION, + greaterthanorequalcass30, + greaterthanorequalcass3_11, +) from tests.integration.cqlengine.base import BaseCassEngTestCase import pytest class TestDatetime(BaseCassEngTestCase): class DatetimeTest(Model): - test_id = Integer(primary_key=True) created_at = DateTime() @@ -60,25 +82,38 @@ def test_datetime_tzinfo_io(self): class TZ(tzinfo): def utcoffset(self, date_time): return timedelta(hours=-1) + def dst(self, date_time): return None now = datetime(1982, 1, 1, tzinfo=TZ()) dt = self.DatetimeTest.objects.create(test_id=1, created_at=now) dt2 = self.DatetimeTest.objects(test_id=1).first() - assert dt2.created_at.timetuple()[:6] == (now + timedelta(hours=1)).timetuple()[:6] + assert ( + dt2.created_at.timetuple()[:6] == (now + timedelta(hours=1)).timetuple()[:6] + ) @greaterthanorequalcass30 def test_datetime_date_support(self): today = date.today() self.DatetimeTest.objects.create(test_id=2, created_at=today) dt2 = self.DatetimeTest.objects(test_id=2).first() - assert dt2.created_at.isoformat() == datetime(today.year, today.month, today.day).isoformat() + assert ( + dt2.created_at.isoformat() + == datetime(today.year, today.month, today.day).isoformat() + ) - result = self.DatetimeTest.objects.all().allow_filtering().filter(test_id=2).first() + result = ( + self.DatetimeTest.objects.all().allow_filtering().filter(test_id=2).first() + ) assert result.created_at == datetime.combine(today, datetime.min.time()) - result = self.DatetimeTest.objects.all().allow_filtering().filter(test_id=2, created_at=today).first() + result = ( + self.DatetimeTest.objects.all() + .allow_filtering() + .filter(test_id=2, created_at=today) + .first() + ) assert result.created_at == datetime.combine(today, datetime.min.time()) def test_datetime_none(self): @@ -86,11 +121,11 @@ def test_datetime_none(self): dt2 = self.DatetimeTest.objects(test_id=3).first() assert dt2.created_at is None - dts = self.DatetimeTest.objects.filter(test_id=3).values_list('created_at') + dts = self.DatetimeTest.objects.filter(test_id=3).values_list("created_at") assert dts[0][0] is None def test_datetime_invalid(self): - dt_value= 'INVALID' + dt_value = "INVALID" with pytest.raises(TypeError): self.DatetimeTest.objects.create(test_id=4, created_at=dt_value) @@ -98,7 +133,9 @@ def test_datetime_timestamp(self): dt_value = 1454520554 self.DatetimeTest.objects.create(test_id=5, created_at=dt_value) dt2 = self.DatetimeTest.objects(test_id=5).first() - assert dt2.created_at == datetime.fromtimestamp(dt_value, tz=timezone.utc).replace(tzinfo=None) + assert dt2.created_at == datetime.fromtimestamp( + dt_value, tz=timezone.utc + ).replace(tzinfo=None) def test_datetime_large(self): dt_value = datetime(2038, 12, 31, 10, 10, 10, 123000) @@ -132,7 +169,6 @@ def test_datetime_truncate_microseconds(self): class TestBoolDefault(BaseCassEngTestCase): class BoolDefaultValueTest(Model): - test_id = Integer(primary_key=True) stuff = Boolean(default=True) @@ -149,7 +185,6 @@ def test_default_is_set(self): class TestBoolValidation(BaseCassEngTestCase): class BoolValidationTest(Model): - test_id = Integer(primary_key=True) bool_column = Boolean() @@ -166,7 +201,6 @@ def test_validation_preserves_none(self): class TestVarInt(BaseCassEngTestCase): class VarIntTest(Model): - test_id = Integer(primary_key=True) bignum = VarInt(primary_key=True) @@ -181,7 +215,9 @@ def tearDownClass(cls): def test_varint_io(self): # TODO: this is a weird test. i changed the number from sys.maxint (which doesn't exist in python 3) # to the giant number below and it broken between runs. - long_int = 92834902384092834092384028340283048239048203480234823048230482304820348239 + long_int = ( + 92834902384092834092384028340283048239048203480234823048230482304820348239 + ) int1 = self.VarIntTest.objects.create(test_id=0, bignum=long_int) int2 = self.VarIntTest.objects(test_id=0).first() assert int1.bignum == int2.bignum @@ -190,7 +226,7 @@ def test_varint_io(self): self.VarIntTest.objects.create(test_id=0, bignum="not_a_number") -class DataType(): +class DataType: @classmethod def setUpClass(cls): if PROTOCOL_VERSION < 4 or CASSANDRA_VERSION < Version("3.0"): @@ -211,18 +247,26 @@ def tearDownClass(cls): def setUp(self): if PROTOCOL_VERSION < 4 or CASSANDRA_VERSION < Version("3.0"): - raise unittest.SkipTest("Protocol v4 datatypes " - "require native protocol 4+ and C* version >=3.0, " - "currently using protocol {0} and C* version {1}". - format(PROTOCOL_VERSION, CASSANDRA_VERSION)) + raise unittest.SkipTest( + "Protocol v4 datatypes " + "require native protocol 4+ and C* version >=3.0, " + "currently using protocol {0} and C* version {1}".format( + PROTOCOL_VERSION, CASSANDRA_VERSION + ) + ) def _check_value_is_correct_in_db(self, value): """ - Check that different ways of reading the value - from the model class give the same expected result + Check that different ways of reading the value + from the model class give the same expected result """ if value is None: - result = self.model_class.objects.all().allow_filtering().filter(test_id=0).first() + result = ( + self.model_class.objects.all() + .allow_filtering() + .filter(test_id=0) + .first() + ) assert result.class_param is None result = self.model_class.objects(test_id=0).first() @@ -238,11 +282,21 @@ def _check_value_is_correct_in_db(self, value): assert isinstance(result.class_param, self.python_klass) assert result.class_param == value_to_compare - result = self.model_class.objects.all().allow_filtering().filter(test_id=0).first() + result = ( + self.model_class.objects.all() + .allow_filtering() + .filter(test_id=0) + .first() + ) assert isinstance(result.class_param, self.python_klass) assert result.class_param == value_to_compare - result = self.model_class.objects.all().allow_filtering().filter(test_id=0, class_param=value).first() + result = ( + self.model_class.objects.all() + .allow_filtering() + .filter(test_id=0, class_param=value) + .first() + ) assert isinstance(result.class_param, self.python_klass) assert result.class_param == value_to_compare @@ -272,29 +326,26 @@ def test_param_io(self): def test_param_none(self): """ - Test that None value is correctly written to the db - and then is correctly read + Test that None value is correctly written to the db + and then is correctly read """ self.model_class.objects.create(test_id=1, class_param=None) dt2 = self.model_class.objects(test_id=1).first() assert dt2.class_param is None - dts = self.model_class.objects(test_id=1).values_list('class_param') + dts = self.model_class.objects(test_id=1).values_list("class_param") assert dts[0][0] is None class TestDate(DataType, BaseCassEngTestCase): @classmethod def setUpClass(cls): - cls.db_klass, cls.python_klass = ( - Date, - util.Date - ) + cls.db_klass, cls.python_klass = (Date, util.Date) cls.first_value, cls.second_value, cls.third_value = ( datetime.utcnow(), util.Date(datetime(1, 1, 1)), - datetime(1, 1, 2) + datetime(1, 1, 2), ) super(TestDate, cls).setUpClass() @@ -302,14 +353,11 @@ def setUpClass(cls): class TestTime(DataType, BaseCassEngTestCase): @classmethod def setUpClass(cls): - cls.db_klass, cls.python_klass = ( - Time, - util.Time - ) + cls.db_klass, cls.python_klass = (Time, util.Time) cls.first_value, cls.second_value, cls.third_value = ( None, util.Time(time(2, 12, 7, 49)), - time(2, 12, 7, 50) + time(2, 12, 7, 50), ) super(TestTime, cls).setUpClass() @@ -317,14 +365,11 @@ def setUpClass(cls): class TestDateTime(DataType, BaseCassEngTestCase): @classmethod def setUpClass(cls): - cls.db_klass, cls.python_klass = ( - DateTime, - datetime - ) + cls.db_klass, cls.python_klass = (DateTime, datetime) cls.first_value, cls.second_value, cls.third_value = ( datetime(2017, 4, 13, 18, 34, 24, 317000), datetime(1, 1, 1), - datetime(1, 1, 2) + datetime(1, 1, 2), ) super(TestDateTime, cls).setUpClass() @@ -332,31 +377,22 @@ def setUpClass(cls): class TestBoolean(DataType, BaseCassEngTestCase): @classmethod def setUpClass(cls): - cls.db_klass, cls.python_klass = ( - Boolean, - bool - ) - cls.first_value, cls.second_value, cls.third_value = ( - None, - False, - True - ) + cls.db_klass, cls.python_klass = (Boolean, bool) + cls.first_value, cls.second_value, cls.third_value = (None, False, True) super(TestBoolean, cls).setUpClass() + @greaterthanorequalcass3_11 class TestDuration(DataType, BaseCassEngTestCase): @classmethod def setUpClass(cls): # setUpClass is executed despite the whole class being skipped if CASSANDRA_VERSION >= Version("3.10"): - cls.db_klass, cls.python_klass = ( - Duration, - util.Duration - ) + cls.db_klass, cls.python_klass = (Duration, util.Duration) cls.first_value, cls.second_value, cls.third_value = ( util.Duration(0, 0, 0), util.Duration(1, 2, 3), - util.Duration(0, 0, 0) + util.Duration(0, 0, 0), ) super(TestDuration, cls).setUpClass() @@ -387,7 +423,7 @@ class TestUDT(DataType, BaseCassEngTestCase): def setUpClass(cls): if PROTOCOL_VERSION < 4 or CASSANDRA_VERSION < Version("3.0"): return - + cls.db_klass, cls.python_klass = UserDefinedType, User cls.first_value = User( age=1, @@ -395,7 +431,7 @@ def setUpClass(cls): map_param={1: time(2, 12, 7, 50), 2: util.Time(time(2, 12, 7, 49))}, list_param=[datetime(1, 1, 2), datetime(1, 1, 3)], set_param=set((datetime(1, 1, 3), util.Date(datetime(1, 1, 1)))), - tuple_param=(datetime(1, 1, 3), 2, False, 1, 2.324, uuid4()) + tuple_param=(datetime(1, 1, 3), 2, False, 1, 2.324, uuid4()), ) cls.second_value = User( @@ -404,7 +440,7 @@ def setUpClass(cls): map_param={1: time(2, 12, 7, 50), 2: util.Time(time(2, 12, 7, 49))}, list_param=[datetime(1, 1, 2), datetime(1, 2, 3)], set_param=None, - tuple_param=(datetime(1, 1, 2), 2, False, 1, 2.324, uuid4()) + tuple_param=(datetime(1, 1, 2), 2, False, 1, 2.324, uuid4()), ) cls.third_value = User( @@ -413,7 +449,7 @@ def setUpClass(cls): map_param={1: time(2, 12, 7, 51), 2: util.Time(time(2, 12, 7, 49))}, list_param=[datetime(1, 1, 2), datetime(1, 1, 4)], set_param=set((datetime(1, 1, 3), util.Date(datetime(1, 1, 2)))), - tuple_param=(None, 3, False, None, 2.3214, uuid4()) + tuple_param=(None, 3, False, None, 2.3214, uuid4()), ) cls.model_class = UserModel @@ -422,7 +458,6 @@ def setUpClass(cls): class TestDecimal(BaseCassEngTestCase): class DecimalTest(Model): - test_id = Integer(primary_key=True) dec_val = Decimal() @@ -435,18 +470,17 @@ def tearDownClass(cls): drop_table(cls.DecimalTest) def test_decimal_io(self): - dt = self.DecimalTest.objects.create(test_id=0, dec_val=D('0.00')) + dt = self.DecimalTest.objects.create(test_id=0, dec_val=D("0.00")) dt2 = self.DecimalTest.objects(test_id=0).first() assert dt2.dec_val == dt.dec_val dt = self.DecimalTest.objects.create(test_id=0, dec_val=5) dt2 = self.DecimalTest.objects(test_id=0).first() - assert dt2.dec_val == D('5') + assert dt2.dec_val == D("5") class TestUUID(BaseCassEngTestCase): class UUIDTest(Model): - test_id = Integer(primary_key=True) a_uuid = UUID(default=uuid4()) @@ -480,7 +514,6 @@ def test_uuid_with_upcase(self): class TestTimeUUID(BaseCassEngTestCase): class TimeUUIDTest(Model): - test_id = Integer(primary_key=True) timeuuid = TimeUUID(default=uuid1()) @@ -505,80 +538,78 @@ def test_timeuuid_io(self): class TestInteger(BaseCassEngTestCase): class IntegerTest(Model): - - test_id = UUID(primary_key=True, default=lambda:uuid4()) + test_id = UUID(primary_key=True, default=lambda: uuid4()) value = Integer(default=0, required=True) def test_default_zero_fields_validate(self): - """ Tests that integer columns with a default value of 0 validate """ + """Tests that integer columns with a default value of 0 validate""" it = self.IntegerTest() it.validate() class TestBigInt(BaseCassEngTestCase): class BigIntTest(Model): - - test_id = UUID(primary_key=True, default=lambda:uuid4()) - value = BigInt(default=0, required=True) + test_id = UUID(primary_key=True, default=lambda: uuid4()) + value = BigInt(default=0, required=True) def test_default_zero_fields_validate(self): - """ Tests that bigint columns with a default value of 0 validate """ + """Tests that bigint columns with a default value of 0 validate""" it = self.BigIntTest() it.validate() class TestAscii(BaseCassEngTestCase): def test_min_length(self): - """ Test arbitrary minimal lengths requirements. """ + """Test arbitrary minimal lengths requirements.""" - Ascii(min_length=0).validate('') - Ascii(min_length=0, required=True).validate('') + Ascii(min_length=0).validate("") + Ascii(min_length=0, required=True).validate("") Ascii(min_length=0).validate(None) - Ascii(min_length=0).validate('kevin') + Ascii(min_length=0).validate("kevin") - Ascii(min_length=1).validate('k') + Ascii(min_length=1).validate("k") - Ascii(min_length=5).validate('kevin') - Ascii(min_length=5).validate('kevintastic') + Ascii(min_length=5).validate("kevin") + Ascii(min_length=5).validate("kevintastic") with pytest.raises(ValidationError): - Ascii(min_length=1).validate('') + Ascii(min_length=1).validate("") with pytest.raises(ValidationError): Ascii(min_length=1).validate(None) with pytest.raises(ValidationError): - Ascii(min_length=6).validate('') + Ascii(min_length=6).validate("") with pytest.raises(ValidationError): Ascii(min_length=6).validate(None) with pytest.raises(ValidationError): - Ascii(min_length=6).validate('kevin') + Ascii(min_length=6).validate("kevin") with pytest.raises(ValueError): Ascii(min_length=-1) def test_max_length(self): - """ Test arbitrary maximal lengths requirements. """ - Ascii(max_length=0).validate('') + """Test arbitrary maximal lengths requirements.""" + Ascii(max_length=0).validate("") Ascii(max_length=0).validate(None) - Ascii(max_length=1).validate('') + Ascii(max_length=1).validate("") Ascii(max_length=1).validate(None) - Ascii(max_length=1).validate('b') + Ascii(max_length=1).validate("b") - Ascii(max_length=5).validate('') + Ascii(max_length=5).validate("") Ascii(max_length=5).validate(None) - Ascii(max_length=5).validate('b') - Ascii(max_length=5).validate('blake') + Ascii(max_length=5).validate("b") + Ascii(max_length=5).validate("blake") with pytest.raises(ValidationError): - Ascii(max_length=0).validate('b') + Ascii(max_length=0).validate("b") with pytest.raises(ValidationError): - Ascii(max_length=5).validate('blaketastic') + Ascii(max_length=5).validate("blaketastic") with pytest.raises(ValueError): Ascii(max_length=-1) @@ -596,9 +627,9 @@ def test_length_range(self): Ascii(min_length=1, max_length=0) def test_type_checking(self): - Ascii().validate('string') - Ascii().validate(u'unicode') - Ascii().validate(bytearray('bytearray', encoding='ascii')) + Ascii().validate("string") + Ascii().validate("unicode") + Ascii().validate(bytearray("bytearray", encoding="ascii")) with pytest.raises(ValidationError): Ascii().validate(5) @@ -606,106 +637,101 @@ def test_type_checking(self): with pytest.raises(ValidationError): Ascii().validate(True) - Ascii().validate("!#$%&\'()*+,-./") + Ascii().validate("!#$%&'()*+,-./") with pytest.raises(ValidationError): - Ascii().validate('Beyonc' + chr(233)) - - if sys.version_info < (3, 1): - with pytest.raises(ValidationError): - Ascii().validate(u'Beyonc' + unichr(233)) + Ascii().validate("Beyonc" + chr(233)) def test_unaltering_validation(self): - """ Test the validation step doesn't re-interpret values. """ - assert Ascii().validate('') == '' + """Test the validation step doesn't re-interpret values.""" + assert Ascii().validate("") == "" assert Ascii().validate(None) == None - assert Ascii().validate('yo') == 'yo' + assert Ascii().validate("yo") == "yo" def test_non_required_validation(self): - """ Tests that validation is ok on none and blank values if required is False. """ - Ascii().validate('') + """Tests that validation is ok on none and blank values if required is False.""" + Ascii().validate("") Ascii().validate(None) def test_required_validation(self): - """ Tests that validation raise on none and blank values if value required. """ - Ascii(required=True).validate('k') + """Tests that validation raise on none and blank values if value required.""" + Ascii(required=True).validate("k") with pytest.raises(ValidationError): - Ascii(required=True).validate('') + Ascii(required=True).validate("") with pytest.raises(ValidationError): Ascii(required=True).validate(None) # With min_length set. - Ascii(required=True, min_length=0).validate('k') - Ascii(required=True, min_length=1).validate('k') + Ascii(required=True, min_length=0).validate("k") + Ascii(required=True, min_length=1).validate("k") with pytest.raises(ValidationError): - Ascii(required=True, min_length=2).validate('k') + Ascii(required=True, min_length=2).validate("k") # With max_length set. - Ascii(required=True, max_length=1).validate('k') + Ascii(required=True, max_length=1).validate("k") with pytest.raises(ValidationError): - Ascii(required=True, max_length=2).validate('kevin') + Ascii(required=True, max_length=2).validate("kevin") with pytest.raises(ValueError): Ascii(required=True, max_length=0) class TestText(BaseCassEngTestCase): - def test_min_length(self): - """ Test arbitrary minimal lengths requirements. """ + """Test arbitrary minimal lengths requirements.""" - Text(min_length=0).validate('') - Text(min_length=0, required=True).validate('') + Text(min_length=0).validate("") + Text(min_length=0, required=True).validate("") Text(min_length=0).validate(None) - Text(min_length=0).validate('blake') + Text(min_length=0).validate("blake") - Text(min_length=1).validate('b') + Text(min_length=1).validate("b") - Text(min_length=5).validate('blake') - Text(min_length=5).validate('blaketastic') + Text(min_length=5).validate("blake") + Text(min_length=5).validate("blaketastic") with pytest.raises(ValidationError): - Text(min_length=1).validate('') + Text(min_length=1).validate("") with pytest.raises(ValidationError): Text(min_length=1).validate(None) with pytest.raises(ValidationError): - Text(min_length=6).validate('') + Text(min_length=6).validate("") with pytest.raises(ValidationError): Text(min_length=6).validate(None) with pytest.raises(ValidationError): - Text(min_length=6).validate('blake') + Text(min_length=6).validate("blake") with pytest.raises(ValueError): Text(min_length=-1) def test_max_length(self): - """ Test arbitrary maximal lengths requirements. """ - Text(max_length=0).validate('') + """Test arbitrary maximal lengths requirements.""" + Text(max_length=0).validate("") Text(max_length=0).validate(None) - Text(max_length=1).validate('') + Text(max_length=1).validate("") Text(max_length=1).validate(None) - Text(max_length=1).validate('b') + Text(max_length=1).validate("b") - Text(max_length=5).validate('') + Text(max_length=5).validate("") Text(max_length=5).validate(None) - Text(max_length=5).validate('b') - Text(max_length=5).validate('blake') + Text(max_length=5).validate("b") + Text(max_length=5).validate("blake") with pytest.raises(ValidationError): - Text(max_length=0).validate('b') + Text(max_length=0).validate("b") with pytest.raises(ValidationError): - Text(max_length=5).validate('blaketastic') + Text(max_length=5).validate("blaketastic") with pytest.raises(ValueError): Text(max_length=-1) @@ -723,9 +749,9 @@ def test_length_range(self): Text(min_length=1, max_length=0) def test_type_checking(self): - Text().validate('string') - Text().validate(u'unicode') - Text().validate(bytearray('bytearray', encoding='ascii')) + Text().validate("string") + Text().validate("unicode") + Text().validate(bytearray("bytearray", encoding="ascii")) with pytest.raises(ValidationError): Text().validate(5) @@ -733,44 +759,42 @@ def test_type_checking(self): with pytest.raises(ValidationError): Text().validate(True) - Text().validate("!#$%&\'()*+,-./") - Text().validate('Beyonc' + chr(233)) - if sys.version_info < (3, 1): - Text().validate(u'Beyonc' + unichr(233)) + Text().validate("!#$%&'()*+,-./") + Text().validate("Beyonc" + chr(233)) def test_unaltering_validation(self): - """ Test the validation step doesn't re-interpret values. """ - assert Text().validate('') == '' + """Test the validation step doesn't re-interpret values.""" + assert Text().validate("") == "" assert Text().validate(None) == None - assert Text().validate('yo') == 'yo' + assert Text().validate("yo") == "yo" def test_non_required_validation(self): - """ Tests that validation is ok on none and blank values if required is False """ - Text().validate('') + """Tests that validation is ok on none and blank values if required is False""" + Text().validate("") Text().validate(None) def test_required_validation(self): - """ Tests that validation raise on none and blank values if value required. """ - Text(required=True).validate('b') + """Tests that validation raise on none and blank values if value required.""" + Text(required=True).validate("b") with pytest.raises(ValidationError): - Text(required=True).validate('') + Text(required=True).validate("") with pytest.raises(ValidationError): Text(required=True).validate(None) # With min_length set. - Text(required=True, min_length=0).validate('b') - Text(required=True, min_length=1).validate('b') + Text(required=True, min_length=0).validate("b") + Text(required=True, min_length=1).validate("b") with pytest.raises(ValidationError): - Text(required=True, min_length=2).validate('b') + Text(required=True, min_length=2).validate("b") # With max_length set. - Text(required=True, max_length=1).validate('b') + Text(required=True, max_length=1).validate("b") with pytest.raises(ValidationError): - Text(required=True, max_length=2).validate('blake') + Text(required=True, max_length=2).validate("blake") with pytest.raises(ValueError): Text(required=True, max_length=0) @@ -778,7 +802,6 @@ def test_required_validation(self): class TestExtraFieldsRaiseException(BaseCassEngTestCase): class TestModel(Model): - id = UUID(primary_key=True, default=uuid4) def test_extra_field(self): @@ -788,15 +811,18 @@ def test_extra_field(self): class TestPythonDoesntDieWhenExtraFieldIsInCassandra(BaseCassEngTestCase): class TestModel(Model): - - __table_name__ = 'alter_doesnt_break_running_app' + __table_name__ = "alter_doesnt_break_running_app" id = UUID(primary_key=True, default=uuid4) def test_extra_field(self): drop_table(self.TestModel) sync_table(self.TestModel) self.TestModel.create() - execute("ALTER TABLE {0} add blah int".format(self.TestModel.column_family_name(include_keyspace=True))) + execute( + "ALTER TABLE {0} add blah int".format( + self.TestModel.column_family_name(include_keyspace=True) + ) + ) self.TestModel.objects.all() @@ -807,9 +833,10 @@ def test_conversion_specific_date(self): uuid = util.uuid_from_time(dt) from uuid import UUID + assert isinstance(uuid, UUID) - ts = (uuid.time - 0x01b21dd213814000) / 1e7 # back to a timestamp + ts = (uuid.time - 0x01B21DD213814000) / 1e7 # back to a timestamp new_dt = datetime.fromtimestamp(ts, tz=timezone.utc).replace(tzinfo=None) # checks that we created a UUID1 with the proper timestamp @@ -817,7 +844,6 @@ def test_conversion_specific_date(self): class TestInet(BaseCassEngTestCase): - class InetTestModel(Model): id = UUID(primary_key=True, default=uuid4) address = Inet() From a8686f933a63dfa5481d823f3be79cc39c9adf70 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Wed, 4 Mar 2026 19:17:03 +0200 Subject: [PATCH 15/18] Remove redundant Python version checks that are always True/False on 3.9+ Three version guards were left over from Python 2/3.x compatibility: 1. cluster.py: 'if sys.version_info[0] >= 3 and sys.version_info[1] >= 7' guarded the Eventlet/futurist ThreadPoolExecutor workaround. Since the driver requires 3.9+, this is always True. Removed the guard, dedented the body, and updated the docstring and error message to drop the 'Python 3.7+' qualifier (the issue is inherent to Eventlet, not a version-specific regression). 2. test_row_factories.py: NAMEDTUPLE_CREATION_BUG was defined as 'sys.version_info >= (3,) and sys.version_info < (3, 7)', which is always False on 3.9+. The test's dead branch tested a warning path that can never trigger. Removed the constant, the dead branch, the unused 'sys' import, and simplified the test to just verify long column lists work. 3. test_insights.py: 'if sys.version_info > (3,)' guarded a namespace suffix that is always needed on Python 3. Removed the guard and the now-unused 'sys' import. All 608 unit tests pass. --- cassandra/cluster.py | 47 ++-- tests/unit/advanced/test_insights.py | 378 +++++++++++++++++---------- tests/unit/test_row_factories.py | 41 +-- 3 files changed, 277 insertions(+), 189 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index af8960504f..10fa0ef7b7 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -1735,39 +1735,38 @@ def _create_thread_pool_executor(self, **kwargs): Create a ThreadPoolExecutor for the cluster. In most cases, the built-in `concurrent.futures.ThreadPoolExecutor` is used. - Python 3.7+ and Eventlet cause the `concurrent.futures.ThreadPoolExecutor` - to hang indefinitely. In that case, the user needs to have the `futurist` + Eventlet causes the `concurrent.futures.ThreadPoolExecutor` to hang + indefinitely. In that case, the user needs to have the `futurist` package so we can use the `futurist.GreenThreadPoolExecutor` class instead. :param kwargs: All keyword args are passed to the ThreadPoolExecutor constructor. :return: A ThreadPoolExecutor instance. """ tpe_class = ThreadPoolExecutor - if sys.version_info[0] >= 3 and sys.version_info[1] >= 7: - try: - from cassandra.io.eventletreactor import EventletConnection + try: + from cassandra.io.eventletreactor import EventletConnection - is_eventlet = issubclass(self.connection_class, EventletConnection) - except: - # Eventlet is not available or can't be detected - return tpe_class(**kwargs) + is_eventlet = issubclass(self.connection_class, EventletConnection) + except: + # Eventlet is not available or can't be detected + return tpe_class(**kwargs) - if is_eventlet: - try: - from futurist import GreenThreadPoolExecutor - - tpe_class = GreenThreadPoolExecutor - except ImportError: - # futurist is not available - raise ImportError( - ( - "Python 3.7+ and Eventlet cause the `concurrent.futures.ThreadPoolExecutor` " - "to hang indefinitely. If you want to use the Eventlet reactor, you " - "need to install the `futurist` package to allow the driver to use " - "the GreenThreadPoolExecutor. See https://github.com/eventlet/eventlet/issues/508 " - "for more details." - ) + if is_eventlet: + try: + from futurist import GreenThreadPoolExecutor + + tpe_class = GreenThreadPoolExecutor + except ImportError: + # futurist is not available + raise ImportError( + ( + "Eventlet causes the `concurrent.futures.ThreadPoolExecutor` " + "to hang indefinitely. If you want to use the Eventlet reactor, you " + "need to install the `futurist` package to allow the driver to use " + "the GreenThreadPoolExecutor. See https://github.com/eventlet/eventlet/issues/508 " + "for more details." ) + ) return tpe_class(**kwargs) diff --git a/tests/unit/advanced/test_insights.py b/tests/unit/advanced/test_insights.py index ec9b918866..24a766b46f 100644 --- a/tests/unit/advanced/test_insights.py +++ b/tests/unit/advanced/test_insights.py @@ -16,16 +16,18 @@ import unittest import logging -import sys from unittest.mock import sentinel from cassandra import ConsistencyLevel from cassandra.cluster import ( - ExecutionProfile, GraphExecutionProfile, ProfileManager, + ExecutionProfile, + GraphExecutionProfile, + ProfileManager, GraphAnalyticsExecutionProfile, - EXEC_PROFILE_DEFAULT, EXEC_PROFILE_GRAPH_DEFAULT, + EXEC_PROFILE_DEFAULT, + EXEC_PROFILE_GRAPH_DEFAULT, EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT, - EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT + EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT, ) from cassandra.datastax.graph.query import GraphOptions from cassandra.datastax.insights.registry import insights_registry @@ -43,7 +45,7 @@ RetryPolicy, SpeculativeExecutionPolicy, ConstantSpeculativeExecutionPolicy, - WrapperPolicy + WrapperPolicy, ) @@ -53,28 +55,32 @@ class TestGetConfig(unittest.TestCase): - def test_invalid_object(self): class NoConfAsDict(object): pass obj = NoConfAsDict() - ns = 'tests.unit.advanced.test_insights' - if sys.version_info > (3,): - ns += '.TestGetConfig.test_invalid_object.' + ns = "tests.unit.advanced.test_insights" + ns += ".TestGetConfig.test_invalid_object." # no default # ... as a policy - assert insights_registry.serialize(obj, policy=True) == {'type': 'NoConfAsDict', - 'namespace': ns, - 'options': {}} + assert insights_registry.serialize(obj, policy=True) == { + "type": "NoConfAsDict", + "namespace": ns, + "options": {}, + } # ... not as a policy (default) - assert insights_registry.serialize(obj) == {'type': 'NoConfAsDict', - 'namespace': ns, - } + assert insights_registry.serialize(obj) == { + "type": "NoConfAsDict", + "namespace": ns, + } # with default - assert insights_registry.serialize(obj, default=sentinel.attr_err_default) is sentinel.attr_err_default + assert ( + insights_registry.serialize(obj, default=sentinel.attr_err_default) + is sentinel.attr_err_default + ) def test_successful_return(self): @@ -88,167 +94,269 @@ class SubclassSentinel(SuperclassSentinel): def superclass_sentinel_serializer(obj): return sentinel.serialized_superclass - assert insights_registry.serialize(SuperclassSentinel()) is sentinel.serialized_superclass - assert insights_registry.serialize(SubclassSentinel()) is sentinel.serialized_superclass + assert ( + insights_registry.serialize(SuperclassSentinel()) + is sentinel.serialized_superclass + ) + assert ( + insights_registry.serialize(SubclassSentinel()) + is sentinel.serialized_superclass + ) # with default -- same behavior - assert insights_registry.serialize(SubclassSentinel(), default=object()) is sentinel.serialized_superclass + assert ( + insights_registry.serialize(SubclassSentinel(), default=object()) + is sentinel.serialized_superclass + ) -class TestConfigAsDict(unittest.TestCase): +class TestConfigAsDict(unittest.TestCase): # graph/query.py def test_graph_options(self): self.maxDiff = None - go = GraphOptions(graph_name='name_for_test', - graph_source='source_for_test', - graph_language='lang_for_test', - graph_protocol='protocol_for_test', - graph_read_consistency_level=ConsistencyLevel.ANY, - graph_write_consistency_level=ConsistencyLevel.ONE, - graph_invalid_option='invalid') + go = GraphOptions( + graph_name="name_for_test", + graph_source="source_for_test", + graph_language="lang_for_test", + graph_protocol="protocol_for_test", + graph_read_consistency_level=ConsistencyLevel.ANY, + graph_write_consistency_level=ConsistencyLevel.ONE, + graph_invalid_option="invalid", + ) log.debug(go._graph_options) - assert insights_registry.serialize(go) == {'source': 'source_for_test', - 'language': 'lang_for_test', - 'graphProtocol': 'protocol_for_test', - # no graph_invalid_option - } + assert insights_registry.serialize(go) == { + "source": "source_for_test", + "language": "lang_for_test", + "graphProtocol": "protocol_for_test", + # no graph_invalid_option + } # cluster.py def test_execution_profile(self): self.maxDiff = None - assert insights_registry.serialize(ExecutionProfile()) == {'consistency': 'LOCAL_ONE', - 'continuousPagingOptions': None, - 'loadBalancing': {'namespace': 'cassandra.policies', - 'options': {'child_policy': {'namespace': 'cassandra.policies', - 'options': {'local_dc': '', - 'used_hosts_per_remote_dc': 0}, - 'type': 'DCAwareRoundRobinPolicy'}, - 'shuffle_replicas': True}, - 'type': 'TokenAwarePolicy'}, - 'readTimeout': 10.0, - 'retry': {'namespace': 'cassandra.policies', 'options': {}, 'type': 'RetryPolicy'}, - 'serialConsistency': None, - 'speculativeExecution': {'namespace': 'cassandra.policies', - 'options': {}, 'type': 'NoSpeculativeExecutionPolicy'}, - 'graphOptions': None - } + assert insights_registry.serialize(ExecutionProfile()) == { + "consistency": "LOCAL_ONE", + "continuousPagingOptions": None, + "loadBalancing": { + "namespace": "cassandra.policies", + "options": { + "child_policy": { + "namespace": "cassandra.policies", + "options": {"local_dc": "", "used_hosts_per_remote_dc": 0}, + "type": "DCAwareRoundRobinPolicy", + }, + "shuffle_replicas": True, + }, + "type": "TokenAwarePolicy", + }, + "readTimeout": 10.0, + "retry": { + "namespace": "cassandra.policies", + "options": {}, + "type": "RetryPolicy", + }, + "serialConsistency": None, + "speculativeExecution": { + "namespace": "cassandra.policies", + "options": {}, + "type": "NoSpeculativeExecutionPolicy", + }, + "graphOptions": None, + } def test_graph_execution_profile(self): self.maxDiff = None - assert insights_registry.serialize(GraphExecutionProfile()) == {'consistency': 'LOCAL_ONE', - 'continuousPagingOptions': None, - 'loadBalancing': {'namespace': 'cassandra.policies', - 'options': {'child_policy': {'namespace': 'cassandra.policies', - 'options': {'local_dc': '', - 'used_hosts_per_remote_dc': 0}, - 'type': 'DCAwareRoundRobinPolicy'}, - 'shuffle_replicas': True}, - 'type': 'TokenAwarePolicy'}, - 'readTimeout': 30.0, - 'retry': {'namespace': 'cassandra.policies', 'options': {}, 'type': 'NeverRetryPolicy'}, - 'serialConsistency': None, - 'speculativeExecution': {'namespace': 'cassandra.policies', - 'options': {}, 'type': 'NoSpeculativeExecutionPolicy'}, - 'graphOptions': {'graphProtocol': None, - 'language': 'gremlin-groovy', - 'source': 'g'}, - } + assert insights_registry.serialize(GraphExecutionProfile()) == { + "consistency": "LOCAL_ONE", + "continuousPagingOptions": None, + "loadBalancing": { + "namespace": "cassandra.policies", + "options": { + "child_policy": { + "namespace": "cassandra.policies", + "options": {"local_dc": "", "used_hosts_per_remote_dc": 0}, + "type": "DCAwareRoundRobinPolicy", + }, + "shuffle_replicas": True, + }, + "type": "TokenAwarePolicy", + }, + "readTimeout": 30.0, + "retry": { + "namespace": "cassandra.policies", + "options": {}, + "type": "NeverRetryPolicy", + }, + "serialConsistency": None, + "speculativeExecution": { + "namespace": "cassandra.policies", + "options": {}, + "type": "NoSpeculativeExecutionPolicy", + }, + "graphOptions": { + "graphProtocol": None, + "language": "gremlin-groovy", + "source": "g", + }, + } def test_graph_analytics_execution_profile(self): self.maxDiff = None - assert insights_registry.serialize(GraphAnalyticsExecutionProfile()) == {'consistency': 'LOCAL_ONE', - 'continuousPagingOptions': None, - 'loadBalancing': {'namespace': 'cassandra.policies', - 'options': {'child_policy': {'namespace': 'cassandra.policies', - 'options': {'child_policy': {'namespace': 'cassandra.policies', - 'options': {'local_dc': '', - 'used_hosts_per_remote_dc': 0}, - 'type': 'DCAwareRoundRobinPolicy'}, - 'shuffle_replicas': True}, - 'type': 'TokenAwarePolicy'}}, - 'type': 'DefaultLoadBalancingPolicy'}, - 'readTimeout': 604800.0, - 'retry': {'namespace': 'cassandra.policies', 'options': {}, 'type': 'NeverRetryPolicy'}, - 'serialConsistency': None, - 'speculativeExecution': {'namespace': 'cassandra.policies', - 'options': {}, 'type': 'NoSpeculativeExecutionPolicy'}, - 'graphOptions': {'graphProtocol': None, - 'language': 'gremlin-groovy', - 'source': 'a'}, - } + assert insights_registry.serialize(GraphAnalyticsExecutionProfile()) == { + "consistency": "LOCAL_ONE", + "continuousPagingOptions": None, + "loadBalancing": { + "namespace": "cassandra.policies", + "options": { + "child_policy": { + "namespace": "cassandra.policies", + "options": { + "child_policy": { + "namespace": "cassandra.policies", + "options": { + "local_dc": "", + "used_hosts_per_remote_dc": 0, + }, + "type": "DCAwareRoundRobinPolicy", + }, + "shuffle_replicas": True, + }, + "type": "TokenAwarePolicy", + } + }, + "type": "DefaultLoadBalancingPolicy", + }, + "readTimeout": 604800.0, + "retry": { + "namespace": "cassandra.policies", + "options": {}, + "type": "NeverRetryPolicy", + }, + "serialConsistency": None, + "speculativeExecution": { + "namespace": "cassandra.policies", + "options": {}, + "type": "NoSpeculativeExecutionPolicy", + }, + "graphOptions": { + "graphProtocol": None, + "language": "gremlin-groovy", + "source": "a", + }, + } # policies.py def test_DC_aware_round_robin_policy(self): - assert insights_registry.serialize(DCAwareRoundRobinPolicy()) == {'namespace': 'cassandra.policies', - 'options': {'local_dc': '', 'used_hosts_per_remote_dc': 0}, - 'type': 'DCAwareRoundRobinPolicy'} - assert insights_registry.serialize(DCAwareRoundRobinPolicy(local_dc='fake_local_dc', - used_hosts_per_remote_dc=15)) == {'namespace': 'cassandra.policies', - 'options': {'local_dc': 'fake_local_dc', 'used_hosts_per_remote_dc': 15}, - 'type': 'DCAwareRoundRobinPolicy'} + assert insights_registry.serialize(DCAwareRoundRobinPolicy()) == { + "namespace": "cassandra.policies", + "options": {"local_dc": "", "used_hosts_per_remote_dc": 0}, + "type": "DCAwareRoundRobinPolicy", + } + assert insights_registry.serialize( + DCAwareRoundRobinPolicy( + local_dc="fake_local_dc", used_hosts_per_remote_dc=15 + ) + ) == { + "namespace": "cassandra.policies", + "options": {"local_dc": "fake_local_dc", "used_hosts_per_remote_dc": 15}, + "type": "DCAwareRoundRobinPolicy", + } def test_token_aware_policy(self): - assert insights_registry.serialize(TokenAwarePolicy(child_policy=LoadBalancingPolicy())) == {'namespace': 'cassandra.policies', - 'options': {'child_policy': {'namespace': 'cassandra.policies', - 'options': {}, - 'type': 'LoadBalancingPolicy'}, - 'shuffle_replicas': True}, - 'type': 'TokenAwarePolicy'} + assert insights_registry.serialize( + TokenAwarePolicy(child_policy=LoadBalancingPolicy()) + ) == { + "namespace": "cassandra.policies", + "options": { + "child_policy": { + "namespace": "cassandra.policies", + "options": {}, + "type": "LoadBalancingPolicy", + }, + "shuffle_replicas": True, + }, + "type": "TokenAwarePolicy", + } def test_whitelist_round_robin_policy(self): - assert insights_registry.serialize(WhiteListRoundRobinPolicy(['127.0.0.3'])) == {'namespace': 'cassandra.policies', - 'options': {'allowed_hosts': ('127.0.0.3',)}, - 'type': 'WhiteListRoundRobinPolicy'} + assert insights_registry.serialize( + WhiteListRoundRobinPolicy(["127.0.0.3"]) + ) == { + "namespace": "cassandra.policies", + "options": {"allowed_hosts": ("127.0.0.3",)}, + "type": "WhiteListRoundRobinPolicy", + } def test_host_filter_policy(self): def my_predicate(s): return False - assert insights_registry.serialize(HostFilterPolicy(LoadBalancingPolicy(), my_predicate)) == {'namespace': 'cassandra.policies', - 'options': {'child_policy': {'namespace': 'cassandra.policies', - 'options': {}, - 'type': 'LoadBalancingPolicy'}, - 'predicate': 'my_predicate'}, - 'type': 'HostFilterPolicy'} + assert insights_registry.serialize( + HostFilterPolicy(LoadBalancingPolicy(), my_predicate) + ) == { + "namespace": "cassandra.policies", + "options": { + "child_policy": { + "namespace": "cassandra.policies", + "options": {}, + "type": "LoadBalancingPolicy", + }, + "predicate": "my_predicate", + }, + "type": "HostFilterPolicy", + } def test_constant_reconnection_policy(self): - assert insights_registry.serialize(ConstantReconnectionPolicy(3, 200)) == {'type': 'ConstantReconnectionPolicy', - 'namespace': 'cassandra.policies', - 'options': {'delay': 3, 'max_attempts': 200} - } + assert insights_registry.serialize(ConstantReconnectionPolicy(3, 200)) == { + "type": "ConstantReconnectionPolicy", + "namespace": "cassandra.policies", + "options": {"delay": 3, "max_attempts": 200}, + } def test_exponential_reconnection_policy(self): - assert insights_registry.serialize(ExponentialReconnectionPolicy(4, 100, 10)) == {'type': 'ExponentialReconnectionPolicy', - 'namespace': 'cassandra.policies', - 'options': {'base_delay': 4, 'max_delay': 100, 'max_attempts': 10} - } + assert insights_registry.serialize( + ExponentialReconnectionPolicy(4, 100, 10) + ) == { + "type": "ExponentialReconnectionPolicy", + "namespace": "cassandra.policies", + "options": {"base_delay": 4, "max_delay": 100, "max_attempts": 10}, + } def test_retry_policy(self): - assert insights_registry.serialize(RetryPolicy()) == {'type': 'RetryPolicy', - 'namespace': 'cassandra.policies', - 'options': {} - } + assert insights_registry.serialize(RetryPolicy()) == { + "type": "RetryPolicy", + "namespace": "cassandra.policies", + "options": {}, + } def test_spec_exec_policy(self): - assert insights_registry.serialize(SpeculativeExecutionPolicy()) == {'type': 'SpeculativeExecutionPolicy', - 'namespace': 'cassandra.policies', - 'options': {} - } + assert insights_registry.serialize(SpeculativeExecutionPolicy()) == { + "type": "SpeculativeExecutionPolicy", + "namespace": "cassandra.policies", + "options": {}, + } def test_constant_spec_exec_policy(self): - assert insights_registry.serialize(ConstantSpeculativeExecutionPolicy(100, 101)) == {'type': 'ConstantSpeculativeExecutionPolicy', - 'namespace': 'cassandra.policies', - 'options': {'delay': 100, - 'max_attempts': 101} - } + assert insights_registry.serialize( + ConstantSpeculativeExecutionPolicy(100, 101) + ) == { + "type": "ConstantSpeculativeExecutionPolicy", + "namespace": "cassandra.policies", + "options": {"delay": 100, "max_attempts": 101}, + } def test_wrapper_policy(self): - assert insights_registry.serialize(WrapperPolicy(LoadBalancingPolicy())) == {'namespace': 'cassandra.policies', - 'options': {'child_policy': {'namespace': 'cassandra.policies', - 'options': {}, - 'type': 'LoadBalancingPolicy'} - }, - 'type': 'WrapperPolicy'} + assert insights_registry.serialize(WrapperPolicy(LoadBalancingPolicy())) == { + "namespace": "cassandra.policies", + "options": { + "child_policy": { + "namespace": "cassandra.policies", + "options": {}, + "type": "LoadBalancingPolicy", + } + }, + "type": "WrapperPolicy", + } diff --git a/tests/unit/test_row_factories.py b/tests/unit/test_row_factories.py index 7787f1d271..d0fe140498 100644 --- a/tests/unit/test_row_factories.py +++ b/tests/unit/test_row_factories.py @@ -18,56 +18,37 @@ import logging import warnings -import sys - from unittest import TestCase log = logging.getLogger(__name__) -NAMEDTUPLE_CREATION_BUG = sys.version_info >= (3,) and sys.version_info < (3, 7) - class TestNamedTupleFactory(TestCase): - long_colnames, long_rows = ( - ['col{}'.format(x) for x in range(300)], - [ - ['value{}'.format(x) for x in range(300)] - for _ in range(100) - ] + ["col{}".format(x) for x in range(300)], + [["value{}".format(x) for x in range(300)] for _ in range(100)], ) short_colnames, short_rows = ( - ['col{}'.format(x) for x in range(200)], - [ - ['value{}'.format(x) for x in range(200)] - for _ in range(100) - ] + ["col{}".format(x) for x in range(200)], + [["value{}".format(x) for x in range(200)] for _ in range(100)], ) - def test_creation_warning_on_long_column_list(self): + def test_creation_on_long_column_list(self): """ - Reproduces the failure described in PYTHON-893 + Verifies that named_tuple_factory handles long column lists. @since 3.15 @jira_ticket PYTHON-893 - @expected_result creation fails on Python > 3 and < 3.7 + @expected_result creation succeeds (the bug only affected Python 3.0-3.6) @test_category row_factory """ - if not NAMEDTUPLE_CREATION_BUG: - named_tuple_factory(self.long_colnames, self.long_rows) - return - with warnings.catch_warnings(record=True) as w: rows = named_tuple_factory(self.long_colnames, self.long_rows) - assert len(w) == 1 - warning = w[0] - assert 'pseudo_namedtuple_factory' in str(warning) - assert '3.7' in str(warning) - - for r in rows: - assert r.col0 == self.long_rows[0][0] + assert len(w) == 0 + assert hasattr(rows[0], "_fields") + assert isinstance(rows[0], tuple) def test_creation_no_warning_on_short_column_list(self): """ @@ -83,5 +64,5 @@ def test_creation_no_warning_on_short_column_list(self): rows = named_tuple_factory(self.short_colnames, self.short_rows) assert len(w) == 0 # check that this is a real namedtuple - assert hasattr(rows[0], '_fields') + assert hasattr(rows[0], "_fields") assert isinstance(rows[0], tuple) From 54bcb9df8b3b5890b06f49329d7f1a11ce76a1c4 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Wed, 4 Mar 2026 19:20:11 +0200 Subject: [PATCH 16/18] Remove u'' string prefixes from production code On Python 3, the u'' prefix is a no-op since all strings are already unicode. These prefixes were left over from Python 2 compatibility and add visual noise without any semantic effect. Removed u'' prefixes from: - cassandra/query.py: __str__ methods for SimpleStatement, PreparedStatement, BoundStatement, BatchStatement, and a docstring example showing OrderedMapSerializedKey output - cassandra/datastax/graph/query.py: GraphStatement.__str__ - cassandra/datastax/graph/fluent/_query.py: TraversalBatch.__str__ and as_graph_statement query construction - docs/conf.py: project name and copyright strings All 608 unit tests pass. --- cassandra/datastax/graph/fluent/_query.py | 81 ++-- cassandra/datastax/graph/query.py | 129 ++++-- cassandra/query.py | 464 ++++++++++++++++------ docs/conf.py | 133 ++++--- 4 files changed, 556 insertions(+), 251 deletions(-) diff --git a/cassandra/datastax/graph/fluent/_query.py b/cassandra/datastax/graph/fluent/_query.py index d5eb7f6373..1e39bd7931 100644 --- a/cassandra/datastax/graph/fluent/_query.py +++ b/cassandra/datastax/graph/fluent/_query.py @@ -21,17 +21,19 @@ from gremlin_python.structure.io.graphsonV2d0 import GraphSONWriter as GraphSONWriterV2 from gremlin_python.structure.io.graphsonV3d0 import GraphSONWriter as GraphSONWriterV3 -from cassandra.datastax.graph.fluent.serializers import GremlinUserTypeIO, \ - dse_graphson2_serializers, dse_graphson3_serializers +from cassandra.datastax.graph.fluent.serializers import ( + GremlinUserTypeIO, + dse_graphson2_serializers, + dse_graphson3_serializers, +) log = logging.getLogger(__name__) -__all__ = ['TraversalBatch', '_query_from_traversal', '_DefaultTraversalBatch'] +__all__ = ["TraversalBatch", "_query_from_traversal", "_DefaultTraversalBatch"] class _GremlinGraphSONWriterAdapter(object): - def __init__(self, context, **kwargs): super(_GremlinGraphSONWriterAdapter, self).__init__(**kwargs) self.context = context @@ -53,14 +55,20 @@ def get_serializer(self, value): # Check if UDT if self.user_types is None: try: - user_types = self.context['cluster']._user_types[self.context['graph_name']] + user_types = self.context["cluster"]._user_types[ + self.context["graph_name"] + ] self.user_types = dict(map(reversed, user_types.items())) except KeyError: self.user_types = {} # Custom detection to map a namedtuple to udt - if (tuple in self.serializers and serializer is self.serializers[tuple] and hasattr(value, '_fields') or - (not serializer and type(value) in self.user_types)): + if ( + tuple in self.serializers + and serializer is self.serializers[tuple] + and hasattr(value, "_fields") + or (not serializer and type(value) in self.user_types) + ): serializer = GremlinUserTypeIO if serializer: @@ -101,13 +109,17 @@ def _query_from_traversal(traversal, graph_protocol, context=None): :param graphson_protocol: The graph protocol to determine the output format. """ if graph_protocol == GraphProtocol.GRAPHSON_2_0: - graphson_writer = graphson2_writer(context, serializer_map=dse_graphson2_serializers) + graphson_writer = graphson2_writer( + context, serializer_map=dse_graphson2_serializers + ) elif graph_protocol == GraphProtocol.GRAPHSON_3_0: if context is None: - raise ValueError('Missing context for GraphSON3 serialization requires.') - graphson_writer = graphson3_writer(context, serializer_map=dse_graphson3_serializers) + raise ValueError("Missing context for GraphSON3 serialization requires.") + graphson_writer = graphson3_writer( + context, serializer_map=dse_graphson3_serializers + ) else: - raise ValueError('Unknown graph protocol: {}'.format(graph_protocol)) + raise ValueError("Unknown graph protocol: {}".format(graph_protocol)) try: query = graphson_writer.writeObject(traversal) @@ -179,12 +191,12 @@ def __len__(self): raise NotImplementedError() def __str__(self): - return u''.format(len(self)) + return "".format(len(self)) + __repr__ = __str__ class _DefaultTraversalBatch(TraversalBatch): - _traversals = None def __init__(self, *args, **kwargs): @@ -193,7 +205,7 @@ def __init__(self, *args, **kwargs): def add(self, traversal): if not isinstance(traversal, GraphTraversal): - raise ValueError('traversal should be a gremlin GraphTraversal') + raise ValueError("traversal should be a gremlin GraphTraversal") self._traversals.append(traversal) return self @@ -202,24 +214,41 @@ def add_all(self, traversals): for traversal in traversals: self.add(traversal) - def as_graph_statement(self, graph_protocol=GraphProtocol.GRAPHSON_2_0, context=None): - statements = [_query_from_traversal(t, graph_protocol, context) for t in self._traversals] - query = u"[{0}]".format(','.join(statements)) + def as_graph_statement( + self, graph_protocol=GraphProtocol.GRAPHSON_2_0, context=None + ): + statements = [ + _query_from_traversal(t, graph_protocol, context) for t in self._traversals + ] + query = "[{0}]".format(",".join(statements)) return SimpleGraphStatement(query) def execute(self): if self._session is None: - raise ValueError('A DSE Session must be provided to execute the traversal batch.') - - execution_profile = self._execution_profile if self._execution_profile else EXEC_PROFILE_GRAPH_DEFAULT - graph_options = self._session.get_execution_profile(execution_profile).graph_options + raise ValueError( + "A DSE Session must be provided to execute the traversal batch." + ) + + execution_profile = ( + self._execution_profile + if self._execution_profile + else EXEC_PROFILE_GRAPH_DEFAULT + ) + graph_options = self._session.get_execution_profile( + execution_profile + ).graph_options context = { - 'cluster': self._session.cluster, - 'graph_name': graph_options.graph_name + "cluster": self._session.cluster, + "graph_name": graph_options.graph_name, } - statement = self.as_graph_statement(graph_options.graph_protocol, context=context) \ - if graph_options.graph_protocol else self.as_graph_statement(context=context) - return self._session.execute_graph(statement, execution_profile=execution_profile) + statement = ( + self.as_graph_statement(graph_options.graph_protocol, context=context) + if graph_options.graph_protocol + else self.as_graph_statement(context=context) + ) + return self._session.execute_graph( + statement, execution_profile=execution_profile + ) def clear(self): del self._traversals[:] diff --git a/cassandra/datastax/graph/query.py b/cassandra/datastax/graph/query.py index 866df7a94c..67b7b426c3 100644 --- a/cassandra/datastax/graph/query.py +++ b/cassandra/datastax/graph/query.py @@ -22,41 +22,67 @@ __all__ = [ - 'GraphProtocol', 'GraphOptions', 'GraphStatement', 'SimpleGraphStatement', - 'single_object_row_factory', 'graph_result_row_factory', 'graph_object_row_factory', - 'graph_graphson2_row_factory', 'Result', 'graph_graphson3_row_factory' + "GraphProtocol", + "GraphOptions", + "GraphStatement", + "SimpleGraphStatement", + "single_object_row_factory", + "graph_result_row_factory", + "graph_object_row_factory", + "graph_graphson2_row_factory", + "Result", + "graph_graphson3_row_factory", ] # (attr, description, server option) _graph_options = ( - ('graph_name', 'name of the targeted graph.', 'graph-name'), - ('graph_source', 'choose the graph traversal source, configured on the server side.', 'graph-source'), - ('graph_language', 'the language used in the queries (default "gremlin-groovy")', 'graph-language'), - ('graph_protocol', 'the graph protocol that the server should use for query results (default "graphson-1-0")', 'graph-results'), - ('graph_read_consistency_level', '''read `cassandra.ConsistencyLevel `_ for graph queries (if distinct from session default). -Setting this overrides the native `Statement.consistency_level `_ for read operations from Cassandra persistence''', 'graph-read-consistency'), - ('graph_write_consistency_level', '''write `cassandra.ConsistencyLevel `_ for graph queries (if distinct from session default). -Setting this overrides the native `Statement.consistency_level `_ for write operations to Cassandra persistence.''', 'graph-write-consistency') + ("graph_name", "name of the targeted graph.", "graph-name"), + ( + "graph_source", + "choose the graph traversal source, configured on the server side.", + "graph-source", + ), + ( + "graph_language", + 'the language used in the queries (default "gremlin-groovy")', + "graph-language", + ), + ( + "graph_protocol", + 'the graph protocol that the server should use for query results (default "graphson-1-0")', + "graph-results", + ), + ( + "graph_read_consistency_level", + """read `cassandra.ConsistencyLevel `_ for graph queries (if distinct from session default). +Setting this overrides the native `Statement.consistency_level `_ for read operations from Cassandra persistence""", + "graph-read-consistency", + ), + ( + "graph_write_consistency_level", + """write `cassandra.ConsistencyLevel `_ for graph queries (if distinct from session default). +Setting this overrides the native `Statement.consistency_level `_ for write operations to Cassandra persistence.""", + "graph-write-consistency", + ), ) _graph_option_names = tuple(option[0] for option in _graph_options) # this is defined by the execution profile attribute, not in graph options -_request_timeout_key = 'request-timeout' +_request_timeout_key = "request-timeout" class GraphProtocol(object): - - GRAPHSON_1_0 = b'graphson-1.0' + GRAPHSON_1_0 = b"graphson-1.0" """ GraphSON1 """ - GRAPHSON_2_0 = b'graphson-2.0' + GRAPHSON_2_0 = b"graphson-2.0" """ GraphSON2 """ - GRAPHSON_3_0 = b'graphson-3.0' + GRAPHSON_3_0 = b"graphson-3.0" """ GraphSON3 """ @@ -66,18 +92,23 @@ class GraphOptions(object): """ Options for DSE Graph Query handler. """ + # See _graph_options map above for notes on valid options DEFAULT_GRAPH_PROTOCOL = GraphProtocol.GRAPHSON_1_0 - DEFAULT_GRAPH_LANGUAGE = b'gremlin-groovy' + DEFAULT_GRAPH_LANGUAGE = b"gremlin-groovy" def __init__(self, **kwargs): self._graph_options = {} - kwargs.setdefault('graph_source', 'g') - kwargs.setdefault('graph_language', GraphOptions.DEFAULT_GRAPH_LANGUAGE) + kwargs.setdefault("graph_source", "g") + kwargs.setdefault("graph_language", GraphOptions.DEFAULT_GRAPH_LANGUAGE) for attr, value in kwargs.items(): if attr not in _graph_option_names: - warn("Unknown keyword argument received for GraphOptions: {0}".format(attr)) + warn( + "Unknown keyword argument received for GraphOptions: {0}".format( + attr + ) + ) setattr(self, attr, value) def copy(self): @@ -98,7 +129,7 @@ def get_options_map(self, other_options=None): options.update(other_options._graph_options) # cls are special-cased so they can be enums in the API, and names in the protocol - for cl in ('graph-write-consistency', 'graph-read-consistency'): + for cl in ("graph-write-consistency", "graph-read-consistency"): cl_enum = options.get(cl) if cl_enum is not None: options[cl] = ConsistencyLevel.value_to_name[cl_enum].encode() @@ -108,19 +139,19 @@ def set_source_default(self): """ Sets ``graph_source`` to the server-defined default traversal source ('default') """ - self.graph_source = 'default' + self.graph_source = "default" def set_source_analytics(self): """ Sets ``graph_source`` to the server-defined analytic traversal source ('a') """ - self.graph_source = 'a' + self.graph_source = "a" def set_source_graph(self): """ Sets ``graph_source`` to the server-defined graph traversal source ('g') """ - self.graph_source = 'g' + self.graph_source = "g" def set_graph_protocol(self, protocol): """ @@ -130,21 +161,21 @@ def set_graph_protocol(self, protocol): @property def is_default_source(self): - return self.graph_source in (b'default', None) + return self.graph_source in (b"default", None) @property def is_analytics_source(self): """ True if ``graph_source`` is set to the server-defined analytics traversal source ('a') """ - return self.graph_source == b'a' + return self.graph_source == b"a" @property def is_graph_source(self): """ True if ``graph_source`` is set to the server-defined graph traversal source ('g') """ - return self.graph_source == b'g' + return self.graph_source == b"g" for opt in _graph_options: @@ -168,14 +199,15 @@ def delete(self, key=opt[2]): class GraphStatement(Statement): - """ An abstract class representing a graph query.""" + """An abstract class representing a graph query.""" @property def query(self): raise NotImplementedError() def __str__(self): - return u''.format(self.query) + return ''.format(self.query) + __repr__ = __str__ @@ -184,6 +216,7 @@ class SimpleGraphStatement(GraphStatement, SimpleStatement): Simple graph statement for :meth:`.Session.execute_graph`. Takes the same parameters as :class:`.SimpleStatement`. """ + @property def query(self): return self._query_string @@ -201,7 +234,7 @@ def graph_result_row_factory(column_names, rows): Returns a :class:`Result ` object that can load graph results and produce specific types. The Result JSON is deserialized and unpacked from the top-level 'result' dict. """ - return [Result(json.loads(row[0])['result']) for row in rows] + return [Result(json.loads(row[0])["result"]) for row in rows] def graph_object_row_factory(column_names, rows): @@ -210,17 +243,17 @@ def graph_object_row_factory(column_names, rows): converted to their simplified objects. Some low-level metadata is shed in this conversion. Unknown result types are still returned as :class:`Result `. """ - return _graph_object_sequence(json.loads(row[0])['result'] for row in rows) + return _graph_object_sequence(json.loads(row[0])["result"] for row in rows) def _graph_object_sequence(objects): for o in objects: res = Result(o) if isinstance(o, dict): - typ = res.value.get('type') - if typ == 'vertex': + typ = res.value.get("type") + if typ == "vertex": res = res.as_vertex() - elif typ == 'edge': + elif typ == "edge": res = res.as_edge() yield res @@ -230,21 +263,23 @@ class _GraphSONContextRowFactory(object): graphson_reader_kwargs = None def __init__(self, cluster): - context = {'cluster': cluster} + context = {"cluster": cluster} kwargs = self.graphson_reader_kwargs or {} self.graphson_reader = self.graphson_reader_class(context, **kwargs) def __call__(self, column_names, rows): - return [self.graphson_reader.read(row[0])['result'] for row in rows] + return [self.graphson_reader.read(row[0])["result"] for row in rows] class _GraphSON2RowFactory(_GraphSONContextRowFactory): """Row factory to deserialize GraphSON2 results.""" + graphson_reader_class = GraphSON2Reader class _GraphSON3RowFactory(_GraphSONContextRowFactory): """Row factory to deserialize GraphSON3 results.""" + graphson_reader_class = GraphSON3Reader @@ -299,7 +334,9 @@ def as_vertex(self): Raises TypeError if parsing fails (i.e. the result structure is not valid). """ try: - return Vertex(self.id, self.label, self.type, self.value.get('properties', {})) + return Vertex( + self.id, self.label, self.type, self.value.get("properties", {}) + ) except (AttributeError, ValueError, TypeError): raise TypeError("Could not create Vertex from %r" % (self,)) @@ -310,8 +347,16 @@ def as_edge(self): Raises TypeError if parsing fails (i.e. the result structure is not valid). """ try: - return Edge(self.id, self.label, self.type, self.value.get('properties', {}), - self.inV, self.inVLabel, self.outV, self.outVLabel) + return Edge( + self.id, + self.label, + self.type, + self.value.get("properties", {}), + self.inV, + self.inVLabel, + self.outV, + self.outVLabel, + ) except (AttributeError, ValueError, TypeError): raise TypeError("Could not create Edge from %r" % (self,)) @@ -327,4 +372,8 @@ def as_path(self): raise TypeError("Could not create Path from %r" % (self,)) def as_vertex_property(self): - return VertexProperty(self.value.get('label'), self.value.get('value'), self.value.get('properties', {})) + return VertexProperty( + self.value.get("label"), + self.value.get("value"), + self.value.get("properties", {}), + ) diff --git a/cassandra/query.py b/cassandra/query.py index 6c6878fdb4..a12e5d5430 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -34,6 +34,7 @@ from cassandra.util import OrderedDict, _sanitize_identifiers import logging + log = logging.getLogger(__name__) UNSET_VALUE = _UNSET_VALUE @@ -49,9 +50,9 @@ Only valid when using native protocol v4+ """ -NON_ALPHA_REGEX = re.compile('[^a-zA-Z0-9]') -START_BADCHAR_REGEX = re.compile('^[^a-zA-Z0-9]*') -END_BADCHAR_REGEX = re.compile('[^a-zA-Z0-9_]*$') +NON_ALPHA_REGEX = re.compile("[^a-zA-Z0-9]") +START_BADCHAR_REGEX = re.compile("^[^a-zA-Z0-9]*") +END_BADCHAR_REGEX = re.compile("[^a-zA-Z0-9_]*$") _clean_name_cache = {} @@ -60,7 +61,9 @@ def _clean_column_name(name): try: return _clean_name_cache[name] except KeyError: - clean = NON_ALPHA_REGEX.sub("_", START_BADCHAR_REGEX.sub("", END_BADCHAR_REGEX.sub("", name))) + clean = NON_ALPHA_REGEX.sub( + "_", START_BADCHAR_REGEX.sub("", END_BADCHAR_REGEX.sub("", name)) + ) _clean_name_cache[name] = clean return clean @@ -83,6 +86,7 @@ def tuple_factory(colnames, rows): """ return rows + class PseudoNamedTupleRow(object): """ Helper class for pseudo_named_tuple_factory. These objects provide an @@ -90,6 +94,7 @@ class PseudoNamedTupleRow(object): but otherwise do not attempt to implement the full namedtuple or iterable interface. """ + def __init__(self, ordered_dict): self._dict = ordered_dict self._tuple = tuple(ordered_dict.values()) @@ -104,8 +109,7 @@ def __iter__(self): return iter(self._tuple) def __repr__(self): - return '{t}({od})'.format(t=self.__class__.__name__, - od=self._dict) + return "{t}({od})".format(t=self.__class__.__name__, od=self._dict) def pseudo_namedtuple_factory(colnames, rows): @@ -113,8 +117,7 @@ def pseudo_namedtuple_factory(colnames, rows): Returns each row as a :class:`.PseudoNamedTupleRow`. This is the fallback factory for cases where :meth:`.named_tuple_factory` fails to create rows. """ - return [PseudoNamedTupleRow(od) - for od in ordered_dict_factory(colnames, rows)] + return [PseudoNamedTupleRow(od) for od in ordered_dict_factory(colnames, rows)] def named_tuple_factory(colnames, rows): @@ -148,7 +151,7 @@ def named_tuple_factory(colnames, rows): """ clean_column_names = map(_clean_column_name, colnames) try: - Row = namedtuple('Row', clean_column_names) + Row = namedtuple("Row", clean_column_names) except SyntaxError: warnings.warn( "Failed creating namedtuple for a result because there were too " @@ -159,19 +162,23 @@ def named_tuple_factory(colnames, rows): "values on row objects, Upgrade to Python 3.7, or use a different " "row factory. (column names: {colnames})".format( substitute_factory_name=pseudo_namedtuple_factory.__name__, - colnames=colnames + colnames=colnames, ) ) return pseudo_namedtuple_factory(colnames, rows) except Exception: - clean_column_names = list(map(_clean_column_name, colnames)) # create list because py3 map object will be consumed by first attempt - log.warning("Failed creating named tuple for results with column names %s (cleaned: %s) " - "(see Python 'namedtuple' documentation for details on name rules). " - "Results will be returned with positional names. " - "Avoid this by choosing different names, using SELECT \"\" AS aliases, " - "or specifying a different row_factory on your Session" % - (colnames, clean_column_names)) - Row = namedtuple('Row', _sanitize_identifiers(clean_column_names)) + clean_column_names = list( + map(_clean_column_name, colnames) + ) # create list because py3 map object will be consumed by first attempt + log.warning( + "Failed creating named tuple for results with column names %s (cleaned: %s) " + "(see Python 'namedtuple' documentation for details on name rules). " + "Results will be returned with positional names. " + 'Avoid this by choosing different names, using SELECT "" AS aliases, ' + "or specifying a different row_factory on your Session" + % (colnames, clean_column_names) + ) + Row = namedtuple("Row", _sanitize_identifiers(clean_column_names)) return [Row(*row) for row in rows] @@ -187,7 +194,7 @@ def dict_factory(colnames, rows): >>> session.row_factory = dict_factory >>> rows = session.execute("SELECT name, age FROM users LIMIT 1") >>> print(rows[0]) - {u'age': 42, u'name': u'Bob'} + {'age': 42, 'name': 'Bob'} .. versionchanged:: 2.0.0 moved from ``cassandra.decoder`` to ``cassandra.query`` @@ -276,11 +283,24 @@ class Statement(object): _serial_consistency_level = None _routing_key = None - def __init__(self, retry_policy=None, consistency_level=None, routing_key=None, - serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, custom_payload=None, - is_idempotent=False, table=None): - if retry_policy and not hasattr(retry_policy, 'on_read_timeout'): # just checking one method to detect positional parameter errors - raise ValueError('retry_policy should implement cassandra.policies.RetryPolicy') + def __init__( + self, + retry_policy=None, + consistency_level=None, + routing_key=None, + serial_consistency_level=None, + fetch_size=FETCH_SIZE_UNSET, + keyspace=None, + custom_payload=None, + is_idempotent=False, + table=None, + ): + if retry_policy and not hasattr( + retry_policy, "on_read_timeout" + ): # just checking one method to detect positional parameter errors + raise ValueError( + "retry_policy should implement cassandra.policies.RetryPolicy" + ) if retry_policy is not None: self.retry_policy = retry_policy if consistency_level is not None: @@ -329,17 +349,20 @@ def _del_routing_key(self): If the partition key is a composite, a list or tuple must be passed in. Each key component should be in its packed (binary) format, so all components should be strings. - """) + """, + ) def _get_serial_consistency_level(self): return self._serial_consistency_level def _set_serial_consistency_level(self, serial_consistency_level): - if (serial_consistency_level is not None and - not ConsistencyLevel.is_serial(serial_consistency_level)): + if serial_consistency_level is not None and not ConsistencyLevel.is_serial( + serial_consistency_level + ): raise ValueError( "serial_consistency_level must be either ConsistencyLevel.SERIAL " - "or ConsistencyLevel.LOCAL_SERIAL") + "or ConsistencyLevel.LOCAL_SERIAL" + ) self._serial_consistency_level = serial_consistency_level def _del_serial_consistency_level(self): @@ -384,7 +407,8 @@ def is_lwt(self): conditional statements. .. versionadded:: 2.0.0 - """) + """, + ) class SimpleStatement(Statement): @@ -392,9 +416,18 @@ class SimpleStatement(Statement): A simple, un-prepared query. """ - def __init__(self, query_string, retry_policy=None, consistency_level=None, routing_key=None, - serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, - custom_payload=None, is_idempotent=False): + def __init__( + self, + query_string, + retry_policy=None, + consistency_level=None, + routing_key=None, + serial_consistency_level=None, + fetch_size=FETCH_SIZE_UNSET, + keyspace=None, + custom_payload=None, + is_idempotent=False, + ): """ `query_string` should be a literal CQL statement with the exception of parameter placeholders that will be filled through the @@ -402,8 +435,17 @@ def __init__(self, query_string, retry_policy=None, consistency_level=None, rout See :class:`Statement` attributes for a description of the other parameters. """ - Statement.__init__(self, retry_policy, consistency_level, routing_key, - serial_consistency_level, fetch_size, keyspace, custom_payload, is_idempotent) + Statement.__init__( + self, + retry_policy, + consistency_level, + routing_key, + serial_consistency_level, + fetch_size, + keyspace, + custom_payload, + is_idempotent, + ) self._query_string = query_string @property @@ -411,9 +453,14 @@ def query_string(self): return self._query_string def __str__(self): - consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') - return (u'' % - (self.query_string, consistency)) + consistency = ConsistencyLevel.value_to_name.get( + self.consistency_level, "Not Set" + ) + return '' % ( + self.query_string, + consistency, + ) + __repr__ = __str__ @@ -442,7 +489,7 @@ class PreparedStatement(object): A note about * in prepared statements """ - column_metadata = None #TODO: make this bind_metadata in next major + column_metadata = None # TODO: make this bind_metadata in next major retry_policy = None consistency_level = None custom_payload = None @@ -459,9 +506,19 @@ class PreparedStatement(object): serial_consistency_level = None # TODO never used? _is_lwt = False - def __init__(self, column_metadata, query_id, routing_key_indexes, query, - keyspace, protocol_version, result_metadata, result_metadata_id, - is_lwt=False, column_encryption_policy=None): + def __init__( + self, + column_metadata, + query_id, + routing_key_indexes, + query, + keyspace, + protocol_version, + result_metadata, + result_metadata_id, + is_lwt=False, + column_encryption_policy=None, + ): self.column_metadata = column_metadata self.query_id = query_id self.routing_key_indexes = routing_key_indexes @@ -475,13 +532,33 @@ def __init__(self, column_metadata, query_id, routing_key_indexes, query, self._is_lwt = is_lwt @classmethod - def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata, - query, prepared_keyspace, protocol_version, result_metadata, - result_metadata_id, is_lwt, column_encryption_policy=None): + def from_message( + cls, + query_id, + column_metadata, + pk_indexes, + cluster_metadata, + query, + prepared_keyspace, + protocol_version, + result_metadata, + result_metadata_id, + is_lwt, + column_encryption_policy=None, + ): if not column_metadata: - return PreparedStatement(column_metadata, query_id, None, - query, prepared_keyspace, protocol_version, result_metadata, - result_metadata_id, is_lwt, column_encryption_policy) + return PreparedStatement( + column_metadata, + query_id, + None, + query, + prepared_keyspace, + protocol_version, + result_metadata, + result_metadata_id, + is_lwt, + column_encryption_policy, + ) if pk_indexes: routing_key_indexes = pk_indexes @@ -496,18 +573,32 @@ def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata, partition_key_columns = table_meta.partition_key # make a map of {column_name: index} for each column in the statement - statement_indexes = dict((c.name, i) for i, c in enumerate(column_metadata)) + statement_indexes = dict( + (c.name, i) for i, c in enumerate(column_metadata) + ) # a list of which indexes in the statement correspond to partition key items try: - routing_key_indexes = [statement_indexes[c.name] - for c in partition_key_columns] - except KeyError: # we're missing a partition key component in the prepared - pass # statement; just leave routing_key_indexes as None - - return PreparedStatement(column_metadata, query_id, routing_key_indexes, - query, prepared_keyspace, protocol_version, result_metadata, - result_metadata_id, is_lwt, column_encryption_policy) + routing_key_indexes = [ + statement_indexes[c.name] for c in partition_key_columns + ] + except ( + KeyError + ): # we're missing a partition key component in the prepared + pass # statement; just leave routing_key_indexes as None + + return PreparedStatement( + column_metadata, + query_id, + routing_key_indexes, + query, + prepared_keyspace, + protocol_version, + result_metadata, + result_metadata_id, + is_lwt, + column_encryption_policy, + ) def bind(self, values): """ @@ -519,16 +610,23 @@ def bind(self, values): def is_routing_key_index(self, i): if self._routing_key_index_set is None: - self._routing_key_index_set = set(self.routing_key_indexes) if self.routing_key_indexes else set() + self._routing_key_index_set = ( + set(self.routing_key_indexes) if self.routing_key_indexes else set() + ) return i in self._routing_key_index_set def is_lwt(self): return self._is_lwt def __str__(self): - consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') - return (u'' % - (self.query_string, consistency)) + consistency = ConsistencyLevel.value_to_name.get( + self.consistency_level, "Not Set" + ) + return '' % ( + self.query_string, + consistency, + ) + __repr__ = __str__ @@ -548,9 +646,17 @@ class BoundStatement(Statement): The sequence of values that were bound to the prepared statement. """ - def __init__(self, prepared_statement, retry_policy=None, consistency_level=None, routing_key=None, - serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, - custom_payload=None): + def __init__( + self, + prepared_statement, + retry_policy=None, + consistency_level=None, + routing_key=None, + serial_consistency_level=None, + fetch_size=FETCH_SIZE_UNSET, + keyspace=None, + custom_payload=None, + ): """ `prepared_statement` should be an instance of :class:`PreparedStatement`. @@ -571,9 +677,17 @@ def __init__(self, prepared_statement, retry_policy=None, consistency_level=None self.keyspace = meta[0].keyspace_name self.table = meta[0].table_name - Statement.__init__(self, retry_policy, consistency_level, routing_key, - serial_consistency_level, fetch_size, keyspace, custom_payload, - prepared_statement.is_idempotent) + Statement.__init__( + self, + retry_policy, + consistency_level, + routing_key, + serial_consistency_level, + fetch_size, + keyspace, + custom_payload, + prepared_statement.is_idempotent, + ) def bind(self, values): """ @@ -615,24 +729,29 @@ def bind(self, values): values.append(UNSET_VALUE) else: raise KeyError( - 'Column name `%s` not found in bound dict.' % - (col.name)) + "Column name `%s` not found in bound dict." % (col.name) + ) value_len = len(values) col_meta_len = len(col_meta) if value_len > col_meta_len: raise ValueError( - "Too many arguments provided to bind() (got %d, expected %d)" % - (len(values), len(col_meta))) + "Too many arguments provided to bind() (got %d, expected %d)" + % (len(values), len(col_meta)) + ) # this is fail-fast for clarity pre-v4. When v4 can be assumed, # the error will be better reported when UNSET_VALUE is implicitly added. - if proto_version < 4 and self.prepared_statement.routing_key_indexes and \ - value_len < len(self.prepared_statement.routing_key_indexes): + if ( + proto_version < 4 + and self.prepared_statement.routing_key_indexes + and value_len < len(self.prepared_statement.routing_key_indexes) + ): raise ValueError( - "Too few arguments provided to bind() (got %d, required %d for routing key)" % - (value_len, len(self.prepared_statement.routing_key_indexes))) + "Too few arguments provided to bind() (got %d, required %d for routing key)" + % (value_len, len(self.prepared_statement.routing_key_indexes)) + ) self.raw_values = values self.values = [] @@ -643,20 +762,30 @@ def bind(self, values): if proto_version >= 4: self._append_unset_value() else: - raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version) + raise ValueError( + "Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" + % proto_version + ) else: try: - col_desc = ColDesc(col_spec.keyspace_name, col_spec.table_name, col_spec.name) + col_desc = ColDesc( + col_spec.keyspace_name, col_spec.table_name, col_spec.name + ) uses_ce = ce_policy and ce_policy.contains_column(col_desc) - col_type = ce_policy.column_type(col_desc) if uses_ce else col_spec.type + col_type = ( + ce_policy.column_type(col_desc) if uses_ce else col_spec.type + ) col_bytes = col_type.serialize(value, proto_version) if uses_ce: col_bytes = ce_policy.encrypt(col_desc, col_bytes) self.values.append(col_bytes) except (TypeError, struct.error) as exc: actual_type = type(value) - message = ('Received an argument of invalid type for column "%s". ' - 'Expected: %s, Got: %s; (%s)' % (col_spec.name, col_spec.type, actual_type, exc)) + message = ( + 'Received an argument of invalid type for column "%s". ' + "Expected: %s, Got: %s; (%s)" + % (col_spec.name, col_spec.type, actual_type, exc) + ) raise TypeError(message) if proto_version >= 4: @@ -671,7 +800,10 @@ def _append_unset_value(self): next_index = len(self.values) if self.prepared_statement.is_routing_key_index(next_index): col_meta = self.prepared_statement.column_metadata[next_index] - raise ValueError("Cannot bind UNSET_VALUE as a part of the routing key '%s'" % col_meta.name) + raise ValueError( + "Cannot bind UNSET_VALUE as a part of the routing key '%s'" + % col_meta.name + ) self.values.append(UNSET_VALUE) @property @@ -686,7 +818,9 @@ def routing_key(self): if len(routing_indexes) == 1: self._routing_key = self.values[routing_indexes[0]] else: - self._routing_key = b"".join(self._key_parts_packed(self.values[i] for i in routing_indexes)) + self._routing_key = b"".join( + self._key_parts_packed(self.values[i] for i in routing_indexes) + ) return self._routing_key @@ -694,9 +828,15 @@ def is_lwt(self): return self.prepared_statement.is_lwt() def __str__(self): - consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') - return (u'' % - (self.prepared_statement.query_string, self.raw_values, consistency)) + consistency = ConsistencyLevel.value_to_name.get( + self.consistency_level, "Not Set" + ) + return '' % ( + self.prepared_statement.query_string, + self.raw_values, + consistency, + ) + __repr__ = __str__ @@ -731,7 +871,7 @@ def __str__(self): return self.name def __repr__(self): - return "BatchType.%s" % (self.name, ) + return "BatchType.%s" % (self.name,) BatchType.LOGGED = BatchType("LOGGED", 0) @@ -763,9 +903,15 @@ class BatchStatement(Statement): _session = None _is_lwt = False - def __init__(self, batch_type=BatchType.LOGGED, retry_policy=None, - consistency_level=None, serial_consistency_level=None, - session=None, custom_payload=None): + def __init__( + self, + batch_type=BatchType.LOGGED, + retry_policy=None, + consistency_level=None, + serial_consistency_level=None, + session=None, + custom_payload=None, + ): """ `batch_type` specifies The :class:`.BatchType` for the batch operation. Defaults to :attr:`.BatchType.LOGGED`. @@ -813,8 +959,13 @@ def __init__(self, batch_type=BatchType.LOGGED, retry_policy=None, self.batch_type = batch_type self._statements_and_parameters = [] self._session = session - Statement.__init__(self, retry_policy=retry_policy, consistency_level=consistency_level, - serial_consistency_level=serial_consistency_level, custom_payload=custom_payload) + Statement.__init__( + self, + retry_policy=retry_policy, + consistency_level=consistency_level, + serial_consistency_level=serial_consistency_level, + custom_payload=custom_payload, + ) def clear(self): """ @@ -853,11 +1004,14 @@ def add(self, statement, parameters=None): if parameters: raise ValueError( "Parameters cannot be passed with a BoundStatement " - "to BatchStatement.add()") + "to BatchStatement.add()" + ) self._update_state(statement) if statement.is_lwt(): self._is_lwt = True - self._add_statement_and_params(True, statement.prepared_statement.query_id, statement.values) + self._add_statement_and_params( + True, statement.prepared_statement.query_id, statement.values + ) else: # it must be a SimpleStatement query_string = statement.query_string @@ -881,7 +1035,9 @@ def add_all(self, statements, parameters): def _add_statement_and_params(self, is_prepared, statement, parameters): if len(self._statements_and_parameters) >= 0xFFFF: - raise ValueError("Batch statement cannot contain more than %d statements." % 0xFFFF) + raise ValueError( + "Batch statement cannot contain more than %d statements." % 0xFFFF + ) self._statements_and_parameters.append((is_prepared, statement, parameters)) def _maybe_set_routing_attributes(self, statement): @@ -907,9 +1063,15 @@ def __len__(self): return len(self._statements_and_parameters) def __str__(self): - consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') - return (u'' % - (self.batch_type, len(self), consistency)) + consistency = ConsistencyLevel.value_to_name.get( + self.consistency_level, "Not Set" + ) + return "" % ( + self.batch_type, + len(self), + consistency, + ) + __repr__ = __str__ @@ -931,7 +1093,9 @@ def __str__(self): def bind_params(query, params, encoder): if isinstance(params, dict): - return query % dict((k, encoder.cql_encode_all_types(v)) for k, v in params.items()) + return query % dict( + (k, encoder.cql_encode_all_types(v)) for k, v in params.items() + ) else: return query % tuple(encoder.cql_encode_all_types(v) for v in params) @@ -940,6 +1104,7 @@ class TraceUnavailable(Exception): """ Raised when complete trace details cannot be fetched from Cassandra. """ + pass @@ -1000,7 +1165,9 @@ class QueryTrace(object): _session = None - _SELECT_SESSIONS_FORMAT = "SELECT * FROM system_traces.sessions WHERE session_id = %s" + _SELECT_SESSIONS_FORMAT = ( + "SELECT * FROM system_traces.sessions WHERE session_id = %s" + ) _SELECT_EVENTS_FORMAT = "SELECT * FROM system_traces.events WHERE session_id = %s" _BASE_RETRY_SLEEP = 0.003 @@ -1029,18 +1196,36 @@ def populate(self, max_wait=2.0, wait_for_complete=True, query_cl=None): time_spent = time.time() - start if max_wait is not None and time_spent >= max_wait: raise TraceUnavailable( - "Trace information was not available within %f seconds. Consider raising Session.max_trace_wait." % (max_wait,)) + "Trace information was not available within %f seconds. Consider raising Session.max_trace_wait." + % (max_wait,) + ) log.debug("Attempting to fetch trace info for trace ID: %s", self.trace_id) - metadata_request_timeout = self._session.cluster.control_connection and self._session.cluster.control_connection._metadata_request_timeout + metadata_request_timeout = ( + self._session.cluster.control_connection + and self._session.cluster.control_connection._metadata_request_timeout + ) session_results = self._execute( - SimpleStatement(maybe_add_timeout_to_query(self._SELECT_SESSIONS_FORMAT, metadata_request_timeout), consistency_level=query_cl), (self.trace_id,), time_spent, max_wait) + SimpleStatement( + maybe_add_timeout_to_query( + self._SELECT_SESSIONS_FORMAT, metadata_request_timeout + ), + consistency_level=query_cl, + ), + (self.trace_id,), + time_spent, + max_wait, + ) # PYTHON-730: There is race condition that the duration mutation is written before started_at the for fast queries session_row = session_results.one() if session_results else None - is_complete = session_row is not None and session_row.duration is not None and session_row.started_at is not None + is_complete = ( + session_row is not None + and session_row.duration is not None + and session_row.started_at is not None + ) if not session_results or (wait_for_complete and not is_complete): - time.sleep(self._BASE_RETRY_SLEEP * (2 ** attempt)) + time.sleep(self._BASE_RETRY_SLEEP * (2**attempt)) attempt += 1 continue if is_complete: @@ -1049,29 +1234,42 @@ def populate(self, max_wait=2.0, wait_for_complete=True, query_cl=None): log.debug("Fetching parital trace info for trace ID: %s", self.trace_id) self.request_type = session_row.request - self.duration = timedelta(microseconds=session_row.duration) if is_complete else None + self.duration = ( + timedelta(microseconds=session_row.duration) if is_complete else None + ) self.started_at = session_row.started_at self.coordinator = session_row.coordinator self.parameters = session_row.parameters # since C* 2.2 - self.client = getattr(session_row, 'client', None) + self.client = getattr(session_row, "client", None) - log.debug("Attempting to fetch trace events for trace ID: %s", self.trace_id) + log.debug( + "Attempting to fetch trace events for trace ID: %s", self.trace_id + ) time_spent = time.time() - start event_results = self._execute( - SimpleStatement(maybe_add_timeout_to_query(self._SELECT_EVENTS_FORMAT, metadata_request_timeout), - consistency_level=query_cl), + SimpleStatement( + maybe_add_timeout_to_query( + self._SELECT_EVENTS_FORMAT, metadata_request_timeout + ), + consistency_level=query_cl, + ), (self.trace_id,), time_spent, - max_wait) + max_wait, + ) log.debug("Fetched trace events for trace ID: %s", self.trace_id) - self.events = tuple(TraceEvent(r.activity, r.event_id, r.source, r.source_elapsed, r.thread) - for r in event_results) + self.events = tuple( + TraceEvent(r.activity, r.event_id, r.source, r.source_elapsed, r.thread) + for r in event_results + ) break def _execute(self, query, parameters, time_spent, max_wait): timeout = (max_wait - time_spent) if max_wait is not None else None - future = self._session._create_response_future(query, parameters, trace=False, custom_payload=None, timeout=timeout) + future = self._session._create_response_future( + query, parameters, trace=False, custom_payload=None, timeout=timeout + ) # in case the user switched the row factory, set it to namedtuple for this query future.row_factory = named_tuple_factory future.send_request() @@ -1079,12 +1277,22 @@ def _execute(self, query, parameters, time_spent, max_wait): try: return future.result() except OperationTimedOut: - raise TraceUnavailable("Trace information was not available within %f seconds" % (max_wait,)) + raise TraceUnavailable( + "Trace information was not available within %f seconds" % (max_wait,) + ) def __str__(self): - return "%s [%s] coordinator: %s, started at: %s, duration: %s, parameters: %s" \ - % (self.request_type, self.trace_id, self.coordinator, self.started_at, - self.duration, self.parameters) + return ( + "%s [%s] coordinator: %s, started at: %s, duration: %s, parameters: %s" + % ( + self.request_type, + self.trace_id, + self.coordinator, + self.started_at, + self.duration, + self.parameters, + ) + ) class TraceEvent(object): @@ -1121,7 +1329,9 @@ class TraceEvent(object): def __init__(self, description, timeuuid, source, source_elapsed, thread_name): self.description = description - self.datetime = datetime.fromtimestamp(unix_time_from_uuid1(timeuuid), tz=timezone.utc) + self.datetime = datetime.fromtimestamp( + unix_time_from_uuid1(timeuuid), tz=timezone.utc + ) self.source = source if source_elapsed is not None: self.source_elapsed = timedelta(microseconds=source_elapsed) @@ -1130,7 +1340,12 @@ def __init__(self, description, timeuuid, source, source_elapsed, thread_name): self.thread_name = thread_name def __str__(self): - return "%s on %s[%s] at %s" % (self.description, self.source, self.thread_name, self.datetime) + return "%s on %s[%s] at %s" % ( + self.description, + self.source, + self.thread_name, + self.datetime, + ) # TODO remove next major since we can target using the `host` attribute of session.execute @@ -1139,9 +1354,12 @@ class HostTargetingStatement(object): Wraps any query statement and attaches a target host, making it usable in a targeted LBP without modifying the user's statement. """ + def __init__(self, inner_statement, target_host): - self.__class__ = type(inner_statement.__class__.__name__, - (self.__class__, inner_statement.__class__), - {}) - self.__dict__ = inner_statement.__dict__ - self.target_host = target_host + self.__class__ = type( + inner_statement.__class__.__name__, + (self.__class__, inner_statement.__class__), + {}, + ) + self.__dict__ = inner_statement.__dict__ + self.target_host = target_host diff --git a/docs/conf.py b/docs/conf.py index 403908c29e..750b89235c 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -4,71 +4,80 @@ from sphinx_scylladb_theme.utils import multiversion_regex_builder -sys.path.insert(0, os.path.abspath('..')) +sys.path.insert(0, os.path.abspath("..")) import cassandra # -- Global variables # Build documentation for the following tags and branches TAGS = [ - '3.21.0-scylla', - '3.22.3-scylla', - '3.24.8-scylla', - '3.25.4-scylla', - '3.25.11-scylla', - '3.26.9-scylla', - '3.28.0-scylla', - '3.28.1-scylla', - '3.28.2-scylla', - '3.29.0-scylla', - '3.29.1-scylla', - '3.29.2-scylla', - '3.29.3-scylla', - '3.29.4-scylla', - '3.29.5-scylla', - '3.29.6-scylla', - '3.29.7-scylla', - '3.29.8-scylla', + "3.21.0-scylla", + "3.22.3-scylla", + "3.24.8-scylla", + "3.25.4-scylla", + "3.25.11-scylla", + "3.26.9-scylla", + "3.28.0-scylla", + "3.28.1-scylla", + "3.28.2-scylla", + "3.29.0-scylla", + "3.29.1-scylla", + "3.29.2-scylla", + "3.29.3-scylla", + "3.29.4-scylla", + "3.29.5-scylla", + "3.29.6-scylla", + "3.29.7-scylla", + "3.29.8-scylla", ] -BRANCHES = ['master'] +BRANCHES = ["master"] # Set the latest version. -LATEST_VERSION = '3.29.8-scylla' +LATEST_VERSION = "3.29.8-scylla" # Set which versions are not released yet. -UNSTABLE_VERSIONS = ['master'] +UNSTABLE_VERSIONS = ["master"] # Set which versions are deprecated -DEPRECATED_VERSIONS = ['3.21.0-scylla', '3.22.3-scylla', '3.24.8-scylla', '3.25.4-scylla', '3.25.11-scylla', '3.26.9-scylla', '3.28.1-scylla', '3.29.1-scylla'] +DEPRECATED_VERSIONS = [ + "3.21.0-scylla", + "3.22.3-scylla", + "3.24.8-scylla", + "3.25.4-scylla", + "3.25.11-scylla", + "3.26.9-scylla", + "3.28.1-scylla", + "3.29.1-scylla", +] -# -- General configuration +# -- General configuration # Add any Sphinx extension module names here, as strings. They can be extensions # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.todo', - 'sphinx.ext.mathjax', - 'sphinx.ext.githubpages', - 'sphinx.ext.extlinks', - 'sphinx_sitemap', - 'sphinx_scylladb_theme', - 'sphinx_multiversion', # optional - 'recommonmark', # optional + "sphinx.ext.autodoc", + "sphinx.ext.todo", + "sphinx.ext.mathjax", + "sphinx.ext.githubpages", + "sphinx.ext.extlinks", + "sphinx_sitemap", + "sphinx_scylladb_theme", + "sphinx_multiversion", # optional + "recommonmark", # optional ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. source_suffix = [".rst", ".md"] # The encoding of source files. -#source_encoding = 'utf-8-sig' +# source_encoding = 'utf-8-sig' # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = u'Scylla Python Driver' -copyright = u'ScyllaDB 2021 and © DataStax 2013-2017' +project = "Scylla Python Driver" +copyright = "ScyllaDB 2021 and © DataStax 2013-2017" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -79,8 +88,8 @@ # The full version, including alpha/beta/rc tags. release = cassandra.__version__ -autodoc_member_order = 'bysource' -autoclass_content = 'both' +autodoc_member_order = "bysource" +autoclass_content = "both" # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. @@ -96,15 +105,15 @@ ] # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # -- Options for not found extension # Template used to render the 404.html generated by this extension. -notfound_template = '404.html' +notfound_template = "404.html" # Prefix added to all the URLs generated in the 404 page. -notfound_urls_prefix = '' +notfound_urls_prefix = "" # -- Options for multiversion @@ -115,13 +124,13 @@ # Defines which version is considered to be the latest stable version. smv_latest_version = LATEST_VERSION # Defines the new name for the latest version. -smv_rename_latest_version = 'stable' +smv_rename_latest_version = "stable" # Whitelist pattern for remotes (set to None to use local branches only) -smv_remote_whitelist = r'^origin$' +smv_remote_whitelist = r"^origin$" # Pattern for released versions -smv_released_pattern = r'^tags/.*$' +smv_released_pattern = r"^tags/.*$" # Format for versioned output directories inside the build directory -smv_outputdir_format = '{ref.name}' +smv_outputdir_format = "{ref.name}" # -- Options for sitemap extension @@ -131,40 +140,40 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = 'sphinx_scylladb_theme' +html_theme = "sphinx_scylladb_theme" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. html_theme_options = { - 'conf_py_path': 'docs/', - 'github_repository': 'scylladb/python-driver', - 'github_issues_repository': 'scylladb/python-driver', - 'hide_edit_this_page_button': 'false', - 'hide_version_dropdown': ['master'], - 'hide_feedback_buttons': 'false', - 'versions_unstable': UNSTABLE_VERSIONS, - 'versions_deprecated': DEPRECATED_VERSIONS, + "conf_py_path": "docs/", + "github_repository": "scylladb/python-driver", + "github_issues_repository": "scylladb/python-driver", + "hide_edit_this_page_button": "false", + "hide_version_dropdown": ["master"], + "hide_feedback_buttons": "false", + "versions_unstable": UNSTABLE_VERSIONS, + "versions_deprecated": DEPRECATED_VERSIONS, } # Custom sidebar templates, maps document names to template names. -html_sidebars = {'**': ['side-nav.html']} +html_sidebars = {"**": ["side-nav.html"]} # If false, no index is generated. html_use_index = False # Output file base name for HTML help builder. -htmlhelp_basename = 'CassandraDriverdoc' +htmlhelp_basename = "CassandraDriverdoc" -# URL which points to the root of the HTML documentation. -html_baseurl = 'https://python-driver.docs.scylladb.com' +# URL which points to the root of the HTML documentation. +html_baseurl = "https://python-driver.docs.scylladb.com" # Dictionary of values to pass into the template engine’s context for all pages -html_context = {'html_baseurl': html_baseurl} +html_context = {"html_baseurl": html_baseurl} autodoc_mock_imports = [ # Asyncore has been removed from python 3.12, we need to mock it until `cassandra/io/asyncorereactor.py` is dropped "asyncore", # Since driver is not built, binary modules also not built, so we need to mock them - "cassandra.io.libevwrapper" + "cassandra.io.libevwrapper", ] From 5647fb122759c16eb069488344bc78286608e0b2 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Wed, 4 Mar 2026 19:22:11 +0200 Subject: [PATCH 17/18] Remove u'' string prefixes from test code On Python 3, the u'' prefix is a no-op since all strings are already unicode. These prefixes were left over from Python 2 compatibility. Removed 67 u-prefix occurrences across 9 test files: - tests/unit/test_types.py (3) - tests/unit/test_orderedmap.py (3) - tests/unit/test_marshalling.py (6) - tests/unit/test_metadata.py (27) - tests/integration/standard/test_types.py (4) - tests/integration/standard/test_query.py (13) - tests/integration/standard/test_cluster.py (8) - tests/integration/cqlengine/model/test_udts.py (1) - tests/integration/cqlengine/model/test_model_io.py (2) All 608 unit tests pass. --- .../cqlengine/model/test_model_io.py | 4 +- .../integration/cqlengine/model/test_udts.py | 2 +- tests/integration/standard/test_cluster.py | 16 +- tests/integration/standard/test_query.py | 16 +- tests/integration/standard/test_types.py | 791 ++++++++++++------ tests/unit/test_marshalling.py | 12 +- tests/unit/test_metadata.py | 36 +- tests/unit/test_orderedmap.py | 90 +- tests/unit/test_types.py | 6 +- 9 files changed, 646 insertions(+), 327 deletions(-) diff --git a/tests/integration/cqlengine/model/test_model_io.py b/tests/integration/cqlengine/model/test_model_io.py index f55815310a..b6d0af4a7f 100644 --- a/tests/integration/cqlengine/model/test_model_io.py +++ b/tests/integration/cqlengine/model/test_model_io.py @@ -694,9 +694,9 @@ def setUp(self): def test_query_with_date(self): uid = uuid4() day = date(2013, 11, 26) - obj = TestQueryModel.create(test_id=uid, date=day, description=u'foo') + obj = TestQueryModel.create(test_id=uid, date=day, description='foo') - assert obj.description == u'foo' + assert obj.description == 'foo' inst = TestQueryModel.filter( TestQueryModel.test_id == uid, diff --git a/tests/integration/cqlengine/model/test_udts.py b/tests/integration/cqlengine/model/test_udts.py index 80f1b9693f..de62077c3f 100644 --- a/tests/integration/cqlengine/model/test_udts.py +++ b/tests/integration/cqlengine/model/test_udts.py @@ -383,7 +383,7 @@ def test_udts_with_unicode(self): @test_category data_types:udt """ ascii_name = 'normal name' - unicode_name = u'Fran\u00E7ois' + unicode_name = 'Fran\u00E7ois' class UserModelText(Model): id = columns.Text(primary_key=True) diff --git a/tests/integration/standard/test_cluster.py b/tests/integration/standard/test_cluster.py index bf62f5df48..97ba88ff08 100644 --- a/tests/integration/standard/test_cluster.py +++ b/tests/integration/standard/test_cluster.py @@ -1267,14 +1267,14 @@ def test_compact_option(self): "({i}, 'a{i}{i}', {i}{i}, {i}{i}, textAsBlob('b{i}{i}'))".format(i=i)) nc_results = nc_session.execute("SELECT * FROM compact_table") - assert set(nc_results.current_rows) == {(1, u'a1', 11, 11, 'b1'), - (1, u'a11', 11, 11, 'b11'), - (2, u'a2', 22, 22, 'b2'), - (2, u'a22', 22, 22, 'b22'), - (3, u'a3', 33, 33, 'b3'), - (3, u'a33', 33, 33, 'b33'), - (4, u'a4', 44, 44, 'b4'), - (4, u'a44', 44, 44, 'b44')} + assert set(nc_results.current_rows) == {(1, 'a1', 11, 11, 'b1'), + (1, 'a11', 11, 11, 'b11'), + (2, 'a2', 22, 22, 'b2'), + (2, 'a22', 22, 22, 'b22'), + (3, 'a3', 33, 33, 'b3'), + (3, 'a33', 33, 33, 'b33'), + (4, 'a4', 44, 44, 'b4'), + (4, 'a44', 44, 44, 'b44')} results = session.execute("SELECT * FROM compact_table") assert set(results.current_rows) == {(1, 11, 11), diff --git a/tests/integration/standard/test_query.py b/tests/integration/standard/test_query.py index 9cebc22b05..ec2309a674 100644 --- a/tests/integration/standard/test_query.py +++ b/tests/integration/standard/test_query.py @@ -323,7 +323,7 @@ def test_column_names(self): result_set = self.session.execute("SELECT * FROM {0}.{1}".format(self.keyspace_name, self.function_table_name)) assert result_set.column_types is not None - assert result_set.column_names == [u'user', u'game', u'year', u'month', u'day', u'score'] + assert result_set.column_names == ['user', 'game', 'year', 'month', 'day', 'score'] @greaterthanorequalcass30 def test_basic_json_query(self): @@ -759,11 +759,11 @@ def test_unicode(self): k int PRIMARY KEY, v text )''' self.session.execute(ddl) - unicode_text = u'Fran\u00E7ois' - query = u'INSERT INTO test3rf.testtext (k, v) VALUES (%s, %s)' + unicode_text = 'Fran\u00E7ois' + query = 'INSERT INTO test3rf.testtext (k, v) VALUES (%s, %s)' try: batch = BatchStatement(BatchType.LOGGED) - batch.add(u"INSERT INTO test3rf.testtext (k, v) VALUES (%s, %s)", (0, unicode_text)) + batch.add("INSERT INTO test3rf.testtext (k, v) VALUES (%s, %s)", (0, unicode_text)) self.session.execute(batch) finally: self.session.execute("DROP TABLE test3rf.testtext") @@ -1338,12 +1338,12 @@ def test_unicode(self): @test_category query """ - unicode_text = u'Fran\u00E7ois' + unicode_text = 'Fran\u00E7ois' batch = BatchStatement(BatchType.LOGGED) - batch.add(u"INSERT INTO {0}.{1} (k, v) VALUES (%s, %s)".format(self.keyspace_name, self.function_table_name), (0, unicode_text)) + batch.add("INSERT INTO {0}.{1} (k, v) VALUES (%s, %s)".format(self.keyspace_name, self.function_table_name), (0, unicode_text)) self.session.execute(batch) - self.session.execute(u"INSERT INTO {0}.{1} (k, v) VALUES (%s, %s)".format(self.keyspace_name, self.function_table_name), (0, unicode_text)) - prepared = self.session.prepare(u"INSERT INTO {0}.{1} (k, v) VALUES (?, ?)".format(self.keyspace_name, self.function_table_name)) + self.session.execute("INSERT INTO {0}.{1} (k, v) VALUES (%s, %s)".format(self.keyspace_name, self.function_table_name), (0, unicode_text)) + prepared = self.session.prepare("INSERT INTO {0}.{1} (k, v) VALUES (?, ?)".format(self.keyspace_name, self.function_table_name)) bound = prepared.bind((1, unicode_text)) self.session.execute(bound) diff --git a/tests/integration/standard/test_types.py b/tests/integration/standard/test_types.py index 1d66ce1ed9..846096982f 100644 --- a/tests/integration/standard/test_types.py +++ b/tests/integration/standard/test_types.py @@ -38,12 +38,27 @@ from tests.unit.cython.utils import cythontest from tests.util import assertEqual -from tests.integration import use_singledc, execute_until_pass, notprotocolv1, \ - BasicSharedKeyspaceUnitTestCase, greaterthancass21, lessthancass30, \ - greaterthanorequalcass3_10, TestCluster, requires_composite_type, \ - requires_vector_type -from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES, COLLECTION_TYPES, PRIMITIVE_DATATYPES_KEYS, \ - get_sample, get_all_samples, get_collection_sample +from tests.integration import ( + use_singledc, + execute_until_pass, + notprotocolv1, + BasicSharedKeyspaceUnitTestCase, + greaterthancass21, + lessthancass30, + greaterthanorequalcass3_10, + TestCluster, + requires_composite_type, + requires_vector_type, +) +from tests.integration.datatype_utils import ( + update_datatypes, + PRIMITIVE_DATATYPES, + COLLECTION_TYPES, + PRIMITIVE_DATATYPES_KEYS, + get_sample, + get_all_samples, + get_collection_sample, +) import pytest @@ -53,7 +68,6 @@ def setup_module(): class TypeTests(BasicSharedKeyspaceUnitTestCase): - @classmethod def setUpClass(cls): # cls._cass_version, cls. = get_server_versions() @@ -68,7 +82,7 @@ def test_can_insert_blob_type_as_string(self): s.execute("CREATE TABLE blobstring (a ascii PRIMARY KEY, b blob)") - params = ['key1', b'blobbyblob'] + params = ["key1", b"blobbyblob"] query = "INSERT INTO blobstring (a, b) VALUES (%s, %s)" s.execute(query, params) @@ -85,14 +99,17 @@ def test_can_insert_blob_type_as_bytearray(self): s.execute("CREATE TABLE blobbytes (a ascii PRIMARY KEY, b blob)") - params = ['key1', bytearray(b'blob1')] + params = ["key1", bytearray(b"blob1")] s.execute("INSERT INTO blobbytes (a, b) VALUES (%s, %s)", params) results = s.execute("SELECT * FROM blobbytes").one() for expected, actual in zip(params, results): assert expected == actual - @unittest.skipIf(not hasattr(cassandra, 'deserializers'), "Cython required for to test DesBytesTypeArray deserializer") + @unittest.skipIf( + not hasattr(cassandra, "deserializers"), + "Cython required for to test DesBytesTypeArray deserializer", + ) def test_des_bytes_type_array(self): """ Simple test to ensure the DesBytesTypeByteArray deserializer functionally works @@ -105,14 +122,15 @@ def test_des_bytes_type_array(self): """ original = None try: - original = cassandra.deserializers.DesBytesType - cassandra.deserializers.DesBytesType = cassandra.deserializers.DesBytesTypeByteArray + cassandra.deserializers.DesBytesType = ( + cassandra.deserializers.DesBytesTypeByteArray + ) s = self.session s.execute("CREATE TABLE blobbytes2 (a ascii PRIMARY KEY, b blob)") - params = ['key1', bytearray(b'blob1')] + params = ["key1", bytearray(b"blob1")] s.execute("INSERT INTO blobbytes2 (a, b) VALUES (%s, %s)", params) results = s.execute("SELECT * FROM blobbytes2").one() @@ -120,7 +138,7 @@ def test_des_bytes_type_array(self): assert expected == actual finally: if original is not None: - cassandra.deserializers.DesBytesType=original + cassandra.deserializers.DesBytesType = original def test_can_insert_primitive_datatypes(self): """ @@ -132,12 +150,12 @@ def test_can_insert_primitive_datatypes(self): # create table alpha_type_list = ["zz int PRIMARY KEY"] col_names = ["zz"] - start_index = ord('a') + start_index = ord("a") for i, datatype in enumerate(PRIMITIVE_DATATYPES): alpha_type_list.append("{0} {1}".format(chr(start_index + i), datatype)) col_names.append(chr(start_index + i)) - s.execute("CREATE TABLE alltypes ({0})".format(', '.join(alpha_type_list))) + s.execute("CREATE TABLE alltypes ({0})".format(", ".join(alpha_type_list))) # create the input params = [0] @@ -145,12 +163,19 @@ def test_can_insert_primitive_datatypes(self): params.append((get_sample(datatype))) # insert into table as a simple statement - columns_string = ', '.join(col_names) - placeholders = ', '.join(["%s"] * len(col_names)) - s.execute("INSERT INTO alltypes ({0}) VALUES ({1})".format(columns_string, placeholders), params) + columns_string = ", ".join(col_names) + placeholders = ", ".join(["%s"] * len(col_names)) + s.execute( + "INSERT INTO alltypes ({0}) VALUES ({1})".format( + columns_string, placeholders + ), + params, + ) # verify data - results = s.execute("SELECT {0} FROM alltypes WHERE zz=0".format(columns_string)).one() + results = s.execute( + "SELECT {0} FROM alltypes WHERE zz=0".format(columns_string) + ).one() for expected, actual in zip(params, results): assert actual == expected @@ -159,29 +184,46 @@ def test_can_insert_primitive_datatypes(self): for i, datatype in enumerate(PRIMITIVE_DATATYPES): single_col_name = chr(start_index + i) single_col_names = ["zz", single_col_name] - placeholders = ','.join(["%s"] * len(single_col_names)) - single_columns_string = ', '.join(single_col_names) + placeholders = ",".join(["%s"] * len(single_col_names)) + single_columns_string = ", ".join(single_col_names) for j, data_sample in enumerate(get_all_samples(datatype)): key = i + 1000 * j single_params = (key, data_sample) - s.execute("INSERT INTO alltypes ({0}) VALUES ({1})".format(single_columns_string, placeholders), - single_params) + s.execute( + "INSERT INTO alltypes ({0}) VALUES ({1})".format( + single_columns_string, placeholders + ), + single_params, + ) # verify data - result = s.execute("SELECT {0} FROM alltypes WHERE zz=%s".format(single_columns_string), (key,)).one()[1] + result = s.execute( + "SELECT {0} FROM alltypes WHERE zz=%s".format( + single_columns_string + ), + (key,), + ).one()[1] compare_value = data_sample - if isinstance(data_sample, ipaddress.IPv4Address) or isinstance(data_sample, ipaddress.IPv6Address): + if isinstance(data_sample, ipaddress.IPv4Address) or isinstance( + data_sample, ipaddress.IPv6Address + ): compare_value = str(data_sample) assert result == compare_value # try the same thing with a prepared statement - placeholders = ','.join(["?"] * len(col_names)) + placeholders = ",".join(["?"] * len(col_names)) s.execute("TRUNCATE alltypes;") - insert = s.prepare("INSERT INTO alltypes ({0}) VALUES ({1})".format(columns_string, placeholders)) + insert = s.prepare( + "INSERT INTO alltypes ({0}) VALUES ({1})".format( + columns_string, placeholders + ) + ) s.execute(insert.bind(params)) # verify data - results = s.execute("SELECT {0} FROM alltypes WHERE zz=0".format(columns_string)).one() + results = s.execute( + "SELECT {0} FROM alltypes WHERE zz=0".format(columns_string) + ).one() for expected, actual in zip(params, results): assert actual == expected @@ -193,8 +235,12 @@ def test_can_insert_primitive_datatypes(self): # verify data with with prepared statement, use dictionary with no explicit columns select = s.prepare("SELECT * FROM alltypes") - results = s.execute(select, - execution_profile=s.execution_profile_clone_update(EXEC_PROFILE_DEFAULT, row_factory=ordered_dict_factory)).one() + results = s.execute( + select, + execution_profile=s.execution_profile_clone_update( + EXEC_PROFILE_DEFAULT, row_factory=ordered_dict_factory + ), + ).one() for expected, actual in zip(params, results.values()): assert actual == expected @@ -214,23 +260,37 @@ def test_can_insert_collection_datatypes(self): # create table alpha_type_list = ["zz int PRIMARY KEY"] col_names = ["zz"] - start_index = ord('a') + start_index = ord("a") for i, collection_type in enumerate(COLLECTION_TYPES): for j, datatype in enumerate(PRIMITIVE_DATATYPES_KEYS): if collection_type == "map": - type_string = "{0}_{1} {2}<{3}, {3}>".format(chr(start_index + i), chr(start_index + j), - collection_type, datatype) + type_string = "{0}_{1} {2}<{3}, {3}>".format( + chr(start_index + i), + chr(start_index + j), + collection_type, + datatype, + ) elif collection_type == "tuple": - type_string = "{0}_{1} frozen<{2}<{3}>>".format(chr(start_index + i), chr(start_index + j), - collection_type, datatype) + type_string = "{0}_{1} frozen<{2}<{3}>>".format( + chr(start_index + i), + chr(start_index + j), + collection_type, + datatype, + ) else: - type_string = "{0}_{1} {2}<{3}>".format(chr(start_index + i), chr(start_index + j), - collection_type, datatype) + type_string = "{0}_{1} {2}<{3}>".format( + chr(start_index + i), + chr(start_index + j), + collection_type, + datatype, + ) alpha_type_list.append(type_string) - col_names.append("{0}_{1}".format(chr(start_index + i), chr(start_index + j))) + col_names.append( + "{0}_{1}".format(chr(start_index + i), chr(start_index + j)) + ) - s.execute("CREATE TABLE allcoltypes ({0})".format(', '.join(alpha_type_list))) - columns_string = ', '.join(col_names) + s.execute("CREATE TABLE allcoltypes ({0})".format(", ".join(alpha_type_list))) + columns_string = ", ".join(col_names) # create the input for simple statement params = [0] @@ -239,11 +299,18 @@ def test_can_insert_collection_datatypes(self): params.append((get_collection_sample(collection_type, datatype))) # insert into table as a simple statement - placeholders = ', '.join(["%s"] * len(col_names)) - s.execute("INSERT INTO allcoltypes ({0}) VALUES ({1})".format(columns_string, placeholders), params) + placeholders = ", ".join(["%s"] * len(col_names)) + s.execute( + "INSERT INTO allcoltypes ({0}) VALUES ({1})".format( + columns_string, placeholders + ), + params, + ) # verify data - results = s.execute("SELECT {0} FROM allcoltypes WHERE zz=0".format(columns_string)).one() + results = s.execute( + "SELECT {0} FROM allcoltypes WHERE zz=0".format(columns_string) + ).one() for expected, actual in zip(params, results): assert actual == expected @@ -254,26 +321,37 @@ def test_can_insert_collection_datatypes(self): params.append((get_collection_sample(collection_type, datatype))) # try the same thing with a prepared statement - placeholders = ','.join(["?"] * len(col_names)) - insert = s.prepare("INSERT INTO allcoltypes ({0}) VALUES ({1})".format(columns_string, placeholders)) + placeholders = ",".join(["?"] * len(col_names)) + insert = s.prepare( + "INSERT INTO allcoltypes ({0}) VALUES ({1})".format( + columns_string, placeholders + ) + ) s.execute(insert.bind(params)) # verify data - results = s.execute("SELECT {0} FROM allcoltypes WHERE zz=0".format(columns_string)).one() + results = s.execute( + "SELECT {0} FROM allcoltypes WHERE zz=0".format(columns_string) + ).one() for expected, actual in zip(params, results): assert actual == expected # verify data with prepared statement query - select = s.prepare("SELECT {0} FROM allcoltypes WHERE zz=?".format(columns_string)) + select = s.prepare( + "SELECT {0} FROM allcoltypes WHERE zz=?".format(columns_string) + ) results = s.execute(select.bind([0])).one() for expected, actual in zip(params, results): assert actual == expected # verify data with with prepared statement, use dictionary with no explicit columns select = s.prepare("SELECT * FROM allcoltypes") - results = s.execute(select, - execution_profile=s.execution_profile_clone_update(EXEC_PROFILE_DEFAULT, - row_factory=ordered_dict_factory)).one() + results = s.execute( + select, + execution_profile=s.execution_profile_clone_update( + EXEC_PROFILE_DEFAULT, row_factory=ordered_dict_factory + ), + ).one() for expected, actual in zip(params, results.values()): assert actual == expected @@ -289,12 +367,16 @@ def test_can_insert_empty_strings_and_nulls(self): # create table alpha_type_list = ["zz int PRIMARY KEY"] col_names = [] - string_types = set(('ascii', 'text', 'varchar')) - string_columns = set(('')) + string_types = set(("ascii", "text", "varchar")) + string_columns = set(("")) # this is just a list of types to try with empty strings - non_string_types = PRIMITIVE_DATATYPES - string_types - set(('blob', 'date', 'inet', 'time', 'timestamp')) + non_string_types = ( + PRIMITIVE_DATATYPES + - string_types + - set(("blob", "date", "inet", "time", "timestamp")) + ) non_string_columns = set() - start_index = ord('a') + start_index = ord("a") for i, datatype in enumerate(PRIMITIVE_DATATYPES): col_name = chr(start_index + i) alpha_type_list.append("{0} {1}".format(col_name, datatype)) @@ -304,32 +386,48 @@ def test_can_insert_empty_strings_and_nulls(self): if datatype in string_types: string_columns.add(col_name) - execute_until_pass(s, "CREATE TABLE all_empty ({0})".format(', '.join(alpha_type_list))) + execute_until_pass( + s, "CREATE TABLE all_empty ({0})".format(", ".join(alpha_type_list)) + ) # verify all types initially null with simple statement - columns_string = ','.join(col_names) + columns_string = ",".join(col_names) s.execute("INSERT INTO all_empty (zz) VALUES (2)") - results = s.execute("SELECT {0} FROM all_empty WHERE zz=2".format(columns_string)).one() + results = s.execute( + "SELECT {0} FROM all_empty WHERE zz=2".format(columns_string) + ).one() assert all(x is None for x in results) # verify all types initially null with prepared statement - select = s.prepare("SELECT {0} FROM all_empty WHERE zz=?".format(columns_string)) + select = s.prepare( + "SELECT {0} FROM all_empty WHERE zz=?".format(columns_string) + ) results = s.execute(select.bind([2])).one() assert all(x is None for x in results) # insert empty strings for string-like fields - expected_values = dict((col, '') for col in string_columns) - columns_string = ','.join(string_columns) - placeholders = ','.join(["%s"] * len(string_columns)) - s.execute("INSERT INTO all_empty (zz, {0}) VALUES (3, {1})".format(columns_string, placeholders), expected_values.values()) + expected_values = dict((col, "") for col in string_columns) + columns_string = ",".join(string_columns) + placeholders = ",".join(["%s"] * len(string_columns)) + s.execute( + "INSERT INTO all_empty (zz, {0}) VALUES (3, {1})".format( + columns_string, placeholders + ), + expected_values.values(), + ) # verify string types empty with simple statement - results = s.execute("SELECT {0} FROM all_empty WHERE zz=3".format(columns_string)).one() + results = s.execute( + "SELECT {0} FROM all_empty WHERE zz=3".format(columns_string) + ).one() for expected, actual in zip(expected_values.values(), results): assert actual == expected # verify string types empty with prepared statement - results = s.execute(s.prepare("SELECT {0} FROM all_empty WHERE zz=?".format(columns_string)), [3]).one() + results = s.execute( + s.prepare("SELECT {0} FROM all_empty WHERE zz=?".format(columns_string)), + [3], + ).one() for expected, actual in zip(expected_values.values(), results): assert actual == expected @@ -337,11 +435,13 @@ def test_can_insert_empty_strings_and_nulls(self): for col in non_string_columns: query = "INSERT INTO all_empty (zz, {0}) VALUES (4, %s)".format(col) with pytest.raises(InvalidRequest): - s.execute(query, ['']) + s.execute(query, [""]) - insert = s.prepare("INSERT INTO all_empty (zz, {0}) VALUES (4, ?)".format(col)) + insert = s.prepare( + "INSERT INTO all_empty (zz, {0}) VALUES (4, ?)".format(col) + ) with pytest.raises(TypeError): - s.execute(insert, ['']) + s.execute(insert, [""]) # verify that Nones can be inserted and overwrites existing data # create the input @@ -350,9 +450,11 @@ def test_can_insert_empty_strings_and_nulls(self): params.append((get_sample(datatype))) # insert the data - columns_string = ','.join(col_names) - placeholders = ','.join(["%s"] * len(col_names)) - simple_insert = "INSERT INTO all_empty (zz, {0}) VALUES (5, {1})".format(columns_string, placeholders) + columns_string = ",".join(col_names) + placeholders = ",".join(["%s"] * len(col_names)) + simple_insert = "INSERT INTO all_empty (zz, {0}) VALUES (5, {1})".format( + columns_string, placeholders + ) s.execute(simple_insert, params) # then insert None, which should null them out @@ -366,7 +468,9 @@ def test_can_insert_empty_strings_and_nulls(self): assert None == col # check via prepared statement - select = s.prepare("SELECT {0} FROM all_empty WHERE zz=?".format(columns_string)) + select = s.prepare( + "SELECT {0} FROM all_empty WHERE zz=?".format(columns_string) + ) results = s.execute(select.bind([5])).one() for col in results: assert None == col @@ -374,8 +478,12 @@ def test_can_insert_empty_strings_and_nulls(self): # do the same thing again, but use a prepared statement to insert the nulls s.execute(simple_insert, params) - placeholders = ','.join(["?"] * len(col_names)) - insert = s.prepare("INSERT INTO all_empty (zz, {0}) VALUES (5, {1})".format(columns_string, placeholders)) + placeholders = ",".join(["?"] * len(col_names)) + insert = s.prepare( + "INSERT INTO all_empty (zz, {0}) VALUES (5, {1})".format( + columns_string, placeholders + ) + ) s.execute(insert, null_values) results = s.execute(query).one() @@ -393,10 +501,14 @@ def test_can_insert_empty_values_for_int32(self): s = self.session execute_until_pass(s, "CREATE TABLE empty_values (a text PRIMARY KEY, b int)") - execute_until_pass(s, "INSERT INTO empty_values (a, b) VALUES ('a', blobAsInt(0x))") + execute_until_pass( + s, "INSERT INTO empty_values (a, b) VALUES ('a', blobAsInt(0x))" + ) try: Int32Type.support_empty_values = True - results = execute_until_pass(s, "SELECT b FROM empty_values WHERE a='a'").one() + results = execute_until_pass( + s, "SELECT b FROM empty_values WHERE a='a'" + ).one() assert EMPTY is results.b finally: Int32Type.support_empty_values = False @@ -408,7 +520,7 @@ def test_timezone_aware_datetimes_are_timestamps(self): from zoneinfo import ZoneInfo - eastern_tz = ZoneInfo('US/Eastern') + eastern_tz = ZoneInfo("US/Eastern") dt = datetime(1997, 8, 29, 11, 14, tzinfo=eastern_tz) s = self.session @@ -440,24 +552,30 @@ def test_can_insert_tuples(self): # use this encoder in order to insert tuples s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple - s.execute("CREATE TABLE tuple_type (a int PRIMARY KEY, b frozen>)") + s.execute( + "CREATE TABLE tuple_type (a int PRIMARY KEY, b frozen>)" + ) # test non-prepared statement - complete = ('foo', 123, True) - s.execute("INSERT INTO tuple_type (a, b) VALUES (0, %s)", parameters=(complete,)) + complete = ("foo", 123, True) + s.execute( + "INSERT INTO tuple_type (a, b) VALUES (0, %s)", parameters=(complete,) + ) result = s.execute("SELECT b FROM tuple_type WHERE a=0").one() assert complete == result.b - partial = ('bar', 456) + partial = ("bar", 456) partial_result = partial + (None,) s.execute("INSERT INTO tuple_type (a, b) VALUES (1, %s)", parameters=(partial,)) result = s.execute("SELECT b FROM tuple_type WHERE a=1").one() assert partial_result == result.b # test single value tuples - subpartial = ('zoo',) + subpartial = ("zoo",) subpartial_result = subpartial + (None, None) - s.execute("INSERT INTO tuple_type (a, b) VALUES (2, %s)", parameters=(subpartial,)) + s.execute( + "INSERT INTO tuple_type (a, b) VALUES (2, %s)", parameters=(subpartial,) + ) result = s.execute("SELECT b FROM tuple_type WHERE a=2").one() assert subpartial_result == result.b @@ -488,7 +606,9 @@ def test_can_insert_tuples_with_varying_lengths(self): raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") c = TestCluster( - execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=dict_factory)} + execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=dict_factory) + } ) s = c.connect(self.keyspace_name) @@ -499,8 +619,11 @@ def test_can_insert_tuples_with_varying_lengths(self): lengths = (1, 2, 3, 384) value_schema = [] for i in lengths: - value_schema += [' v_%s frozen>' % (i, ', '.join(['int'] * i))] - s.execute("CREATE TABLE tuple_lengths (k int PRIMARY KEY, %s)" % (', '.join(value_schema),)) + value_schema += [" v_%s frozen>" % (i, ", ".join(["int"] * i))] + s.execute( + "CREATE TABLE tuple_lengths (k int PRIMARY KEY, %s)" + % (", ".join(value_schema),) + ) # insert tuples into same key using different columns # and verify the results @@ -508,15 +631,20 @@ def test_can_insert_tuples_with_varying_lengths(self): # ensure tuples of larger sizes throw an error created_tuple = tuple(range(0, i + 1)) with pytest.raises(InvalidRequest): - s.execute("INSERT INTO tuple_lengths (k, v_%s) VALUES (0, %s)", (i, created_tuple)) + s.execute( + "INSERT INTO tuple_lengths (k, v_%s) VALUES (0, %s)", + (i, created_tuple), + ) # ensure tuples of proper sizes are written and read correctly created_tuple = tuple(range(0, i)) - s.execute("INSERT INTO tuple_lengths (k, v_%s) VALUES (0, %s)", (i, created_tuple)) + s.execute( + "INSERT INTO tuple_lengths (k, v_%s) VALUES (0, %s)", (i, created_tuple) + ) result = s.execute("SELECT v_%s FROM tuple_lengths WHERE k=0", (i,)).one() - assert tuple(created_tuple) == result['v_%s' % i] + assert tuple(created_tuple) == result["v_%s" % i] c.shutdown() def test_can_insert_tuples_all_primitive_datatypes(self): @@ -531,9 +659,11 @@ def test_can_insert_tuples_all_primitive_datatypes(self): s = c.connect(self.keyspace_name) s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple - s.execute("CREATE TABLE tuple_primitive (" - "k int PRIMARY KEY, " - "v frozen>)" % ','.join(PRIMITIVE_DATATYPES)) + s.execute( + "CREATE TABLE tuple_primitive (" + "k int PRIMARY KEY, " + "v frozen>)" % ",".join(PRIMITIVE_DATATYPES) + ) values = [] type_count = len(PRIMITIVE_DATATYPES) @@ -542,7 +672,9 @@ def test_can_insert_tuples_all_primitive_datatypes(self): # responses have trailing None values for every element that has not been written values.append(get_sample(data_type)) expected = tuple(values + [None] * (type_count - len(values))) - s.execute("INSERT INTO tuple_primitive (k, v) VALUES (%s, %s)", (i, tuple(values))) + s.execute( + "INSERT INTO tuple_primitive (k, v) VALUES (%s, %s)", (i, tuple(values)) + ) result = s.execute("SELECT v FROM tuple_primitive WHERE k=%s", (i,)).one() assert result.v == expected c.shutdown() @@ -556,7 +688,9 @@ def test_can_insert_tuples_all_collection_datatypes(self): raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") c = TestCluster( - execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=dict_factory)} + execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=dict_factory) + } ) s = c.connect(self.keyspace_name) @@ -567,62 +701,87 @@ def test_can_insert_tuples_all_collection_datatypes(self): # create list values for datatype in PRIMITIVE_DATATYPES_KEYS: - values.append('v_{0} frozen>>'.format(len(values), datatype)) + values.append( + "v_{0} frozen>>".format(len(values), datatype) + ) # create set values for datatype in PRIMITIVE_DATATYPES_KEYS: - values.append('v_{0} frozen>>'.format(len(values), datatype)) + values.append("v_{0} frozen>>".format(len(values), datatype)) # create map values for datatype in PRIMITIVE_DATATYPES_KEYS: datatype_1 = datatype_2 = datatype - if datatype == 'blob': + if datatype == "blob": # unhashable type: 'bytearray' - datatype_1 = 'ascii' - values.append('v_{0} frozen>>'.format(len(values), datatype_1, datatype_2)) + datatype_1 = "ascii" + values.append( + "v_{0} frozen>>".format( + len(values), datatype_1, datatype_2 + ) + ) # make sure we're testing all non primitive data types in the future - if set(COLLECTION_TYPES) != set(['tuple', 'list', 'map', 'set']): - raise NotImplemented('Missing datatype not implemented: {}'.format( - set(COLLECTION_TYPES) - set(['tuple', 'list', 'map', 'set']) - )) + if set(COLLECTION_TYPES) != set(["tuple", "list", "map", "set"]): + raise NotImplemented( + "Missing datatype not implemented: {}".format( + set(COLLECTION_TYPES) - set(["tuple", "list", "map", "set"]) + ) + ) # create table - s.execute("CREATE TABLE tuple_non_primative (" - "k int PRIMARY KEY, " - "%s)" % ', '.join(values)) + s.execute( + "CREATE TABLE tuple_non_primative (" + "k int PRIMARY KEY, " + "%s)" % ", ".join(values) + ) i = 0 # test tuple> for datatype in PRIMITIVE_DATATYPES_KEYS: created_tuple = tuple([[get_sample(datatype)]]) - s.execute("INSERT INTO tuple_non_primative (k, v_%s) VALUES (0, %s)", (i, created_tuple)) - - result = s.execute("SELECT v_%s FROM tuple_non_primative WHERE k=0", (i,)).one() - assert created_tuple == result['v_%s' % i] + s.execute( + "INSERT INTO tuple_non_primative (k, v_%s) VALUES (0, %s)", + (i, created_tuple), + ) + + result = s.execute( + "SELECT v_%s FROM tuple_non_primative WHERE k=0", (i,) + ).one() + assert created_tuple == result["v_%s" % i] i += 1 # test tuple> for datatype in PRIMITIVE_DATATYPES_KEYS: created_tuple = tuple([sortedset([get_sample(datatype)])]) - s.execute("INSERT INTO tuple_non_primative (k, v_%s) VALUES (0, %s)", (i, created_tuple)) - - result = s.execute("SELECT v_%s FROM tuple_non_primative WHERE k=0", (i,)).one() - assert created_tuple == result['v_%s' % i] + s.execute( + "INSERT INTO tuple_non_primative (k, v_%s) VALUES (0, %s)", + (i, created_tuple), + ) + + result = s.execute( + "SELECT v_%s FROM tuple_non_primative WHERE k=0", (i,) + ).one() + assert created_tuple == result["v_%s" % i] i += 1 # test tuple> for datatype in PRIMITIVE_DATATYPES_KEYS: - if datatype == 'blob': + if datatype == "blob": # unhashable type: 'bytearray' - created_tuple = tuple([{get_sample('ascii'): get_sample(datatype)}]) + created_tuple = tuple([{get_sample("ascii"): get_sample(datatype)}]) else: created_tuple = tuple([{get_sample(datatype): get_sample(datatype)}]) - s.execute("INSERT INTO tuple_non_primative (k, v_%s) VALUES (0, %s)", (i, created_tuple)) + s.execute( + "INSERT INTO tuple_non_primative (k, v_%s) VALUES (0, %s)", + (i, created_tuple), + ) - result = s.execute("SELECT v_%s FROM tuple_non_primative WHERE k=0", (i,)).one() - assert created_tuple == result['v_%s' % i] + result = s.execute( + "SELECT v_%s FROM tuple_non_primative WHERE k=0", (i,) + ).one() + assert created_tuple == result["v_%s" % i] i += 1 c.shutdown() @@ -632,9 +791,9 @@ def nested_tuples_schema_helper(self, depth): """ if depth == 0: - return 'int' + return "int" else: - return 'tuple<%s>' % self.nested_tuples_schema_helper(depth - 1) + return "tuple<%s>" % self.nested_tuples_schema_helper(depth - 1) def nested_tuples_creator_helper(self, depth): """ @@ -644,7 +803,7 @@ def nested_tuples_creator_helper(self, depth): if depth == 0: return 303 else: - return (self.nested_tuples_creator_helper(depth - 1), ) + return (self.nested_tuples_creator_helper(depth - 1),) def test_can_insert_nested_tuples(self): """ @@ -655,7 +814,9 @@ def test_can_insert_nested_tuples(self): raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") c = TestCluster( - execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=dict_factory)} + execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=dict_factory) + } ) s = c.connect(self.keyspace_name) @@ -663,27 +824,37 @@ def test_can_insert_nested_tuples(self): s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple # create a table with multiple sizes of nested tuples - s.execute("CREATE TABLE nested_tuples (" - "k int PRIMARY KEY, " - "v_1 frozen<%s>," - "v_2 frozen<%s>," - "v_3 frozen<%s>," - "v_32 frozen<%s>" - ")" % (self.nested_tuples_schema_helper(1), - self.nested_tuples_schema_helper(2), - self.nested_tuples_schema_helper(3), - self.nested_tuples_schema_helper(32))) + s.execute( + "CREATE TABLE nested_tuples (" + "k int PRIMARY KEY, " + "v_1 frozen<%s>," + "v_2 frozen<%s>," + "v_3 frozen<%s>," + "v_32 frozen<%s>" + ")" + % ( + self.nested_tuples_schema_helper(1), + self.nested_tuples_schema_helper(2), + self.nested_tuples_schema_helper(3), + self.nested_tuples_schema_helper(32), + ) + ) for i in (1, 2, 3, 32): # create tuple created_tuple = self.nested_tuples_creator_helper(i) # write tuple - s.execute("INSERT INTO nested_tuples (k, v_%s) VALUES (%s, %s)", (i, i, created_tuple)) + s.execute( + "INSERT INTO nested_tuples (k, v_%s) VALUES (%s, %s)", + (i, i, created_tuple), + ) # verify tuple was written and read correctly - result = s.execute("SELECT v_%s FROM nested_tuples WHERE k=%s", (i, i)).one() - assert created_tuple == result['v_%s' % i] + result = s.execute( + "SELECT v_%s FROM nested_tuples WHERE k=%s", (i, i) + ).one() + assert created_tuple == result["v_%s" % i] c.shutdown() def test_can_insert_tuples_with_nulls(self): @@ -696,7 +867,9 @@ def test_can_insert_tuples_with_nulls(self): s = self.session - s.execute("CREATE TABLE tuples_nulls (k int PRIMARY KEY, t frozen>)") + s.execute( + "CREATE TABLE tuples_nulls (k int PRIMARY KEY, t frozen>)" + ) insert = s.prepare("INSERT INTO tuples_nulls (k, t) VALUES (0, ?)") s.execute(insert, [(None, None, None, None)]) @@ -708,10 +881,10 @@ def test_can_insert_tuples_with_nulls(self): assert (None, None, None, None) == s.execute(read).one().t # also test empty strings where compatible - s.execute(insert, [('', None, None, b'')]) + s.execute(insert, [("", None, None, b"")]) result = s.execute("SELECT * FROM tuples_nulls WHERE k=0") - assert ('', None, None, b'') == result.one().t - assert ('', None, None, b'') == s.execute(read).one().t + assert ("", None, None, b"") == result.one().t + assert ("", None, None, b"") == s.execute(read).one().t def test_insert_collection_with_null_fails(self): """ @@ -721,52 +894,62 @@ def test_insert_collection_with_null_fails(self): """ s = self.session columns = [] - for collection_type in ['list', 'set']: + for collection_type in ["list", "set"]: for simple_type in PRIMITIVE_DATATYPES_KEYS: - columns.append(f'{collection_type}_{simple_type} {collection_type}<{simple_type}>') + columns.append( + f"{collection_type}_{simple_type} {collection_type}<{simple_type}>" + ) for simple_type in PRIMITIVE_DATATYPES_KEYS: - columns.append(f'map_k_{simple_type} map<{simple_type}, ascii>') - columns.append(f'map_v_{simple_type} map') - s.execute(f'CREATE TABLE collection_nulls (k int PRIMARY KEY, {", ".join(columns)})') + columns.append(f"map_k_{simple_type} map<{simple_type}, ascii>") + columns.append(f"map_v_{simple_type} map") + s.execute( + f"CREATE TABLE collection_nulls (k int PRIMARY KEY, {', '.join(columns)})" + ) def raises_simple_and_prepared(exc_type, query_str, args): with pytest.raises(exc_type): s.execute(query_str, args) - p = s.prepare(query_str.replace('%s', '?')) + p = s.prepare(query_str.replace("%s", "?")) with pytest.raises(exc_type): s.execute(p, args) i = 0 for simple_type in PRIMITIVE_DATATYPES_KEYS: - query_str = f'INSERT INTO collection_nulls (k, set_{simple_type}) VALUES (%s, %s)' + query_str = ( + f"INSERT INTO collection_nulls (k, set_{simple_type}) VALUES (%s, %s)" + ) args = [i, sortedset([None, get_sample(simple_type)])] raises_simple_and_prepared(InvalidRequest, query_str, args) i += 1 for simple_type in PRIMITIVE_DATATYPES_KEYS: - query_str = f'INSERT INTO collection_nulls (k, list_{simple_type}) VALUES (%s, %s)' + query_str = ( + f"INSERT INTO collection_nulls (k, list_{simple_type}) VALUES (%s, %s)" + ) args = [i, [None, get_sample(simple_type)]] raises_simple_and_prepared(InvalidRequest, query_str, args) i += 1 for simple_type in PRIMITIVE_DATATYPES_KEYS: - query_str = f'INSERT INTO collection_nulls (k, map_k_{simple_type}) VALUES (%s, %s)' - args = [i, OrderedMap([(get_sample(simple_type), 'abc'), (None, 'def')])] + query_str = ( + f"INSERT INTO collection_nulls (k, map_k_{simple_type}) VALUES (%s, %s)" + ) + args = [i, OrderedMap([(get_sample(simple_type), "abc"), (None, "def")])] raises_simple_and_prepared(InvalidRequest, query_str, args) i += 1 for simple_type in PRIMITIVE_DATATYPES_KEYS: - query_str = f'INSERT INTO collection_nulls (k, map_v_{simple_type}) VALUES (%s, %s)' - args = [i, OrderedMap([('abc', None), ('def', get_sample(simple_type))])] + query_str = ( + f"INSERT INTO collection_nulls (k, map_v_{simple_type}) VALUES (%s, %s)" + ) + args = [i, OrderedMap([("abc", None), ("def", get_sample(simple_type))])] raises_simple_and_prepared(InvalidRequest, query_str, args) i += 1 - - def test_can_insert_unicode_query_string(self): """ Test to ensure unicode strings can be used in a query """ s = self.session - s.execute(u"SELECT * FROM system.local WHERE key = 'ef\u2052ef'") - s.execute(u"SELECT * FROM system.local WHERE key = %s", (u"fe\u2051fe",)) + s.execute("SELECT * FROM system.local WHERE key = 'ef\u2052ef'") + s.execute("SELECT * FROM system.local WHERE key = %s", ("fe\u2051fe",)) @requires_composite_type def test_can_read_composite_type(self): @@ -785,13 +968,13 @@ def test_can_read_composite_type(self): s.execute("INSERT INTO composites (a, b) VALUES (0, 'abc:123')") result = s.execute("SELECT * FROM composites WHERE a = 0").one() assert 0 == result.a - assert ('abc', 123) == result.b + assert ("abc", 123) == result.b # CompositeType values can omit elements at the end s.execute("INSERT INTO composites (a, b) VALUES (0, 'abc')") result = s.execute("SELECT * FROM composites WHERE a = 0").one() assert 0 == result.a - assert ('abc',) == result.b + assert ("abc",) == result.b @notprotocolv1 def test_special_float_cql_encoding(self): @@ -811,7 +994,7 @@ def test_special_float_cql_encoding(self): f float PRIMARY KEY, d double )""") - items = (float('nan'), float('inf'), float('-inf')) + items = (float("nan"), float("inf"), float("-inf")) def verify_insert_select(ins_statement, sel_statement): execute_concurrent_with_args(s, ins_statement, ((f, f) for f in items)) @@ -825,14 +1008,18 @@ def verify_insert_select(ins_statement, sel_statement): assert row.d == f # cql encoding - verify_insert_select('INSERT INTO float_cql_encoding (f, d) VALUES (%s, %s)', - 'SELECT * FROM float_cql_encoding WHERE f=%s') + verify_insert_select( + "INSERT INTO float_cql_encoding (f, d) VALUES (%s, %s)", + "SELECT * FROM float_cql_encoding WHERE f=%s", + ) s.execute("TRUNCATE float_cql_encoding") # prepared binding - verify_insert_select(s.prepare('INSERT INTO float_cql_encoding (f, d) VALUES (?, ?)'), - s.prepare('SELECT * FROM float_cql_encoding WHERE f=?')) + verify_insert_select( + s.prepare("INSERT INTO float_cql_encoding (f, d) VALUES (?, ?)"), + s.prepare("SELECT * FROM float_cql_encoding WHERE f=?"), + ) @cythontest def test_cython_decimal(self): @@ -846,11 +1033,19 @@ def test_cython_decimal(self): @test_category data_types serialization """ - self.session.execute("CREATE TABLE {0} (dc decimal PRIMARY KEY)".format(self.function_table_name)) + self.session.execute( + "CREATE TABLE {0} (dc decimal PRIMARY KEY)".format(self.function_table_name) + ) try: - self.session.execute("INSERT INTO {0} (dc) VALUES (-1.08430792318105707)".format(self.function_table_name)) - results = self.session.execute("SELECT * FROM {0}".format(self.function_table_name)) - assert str(results.one().dc) == '-1.08430792318105707' + self.session.execute( + "INSERT INTO {0} (dc) VALUES (-1.08430792318105707)".format( + self.function_table_name + ) + ) + results = self.session.execute( + "SELECT * FROM {0}".format(self.function_table_name) + ) + assert str(results.one().dc) == "-1.08430792318105707" finally: self.session.execute("DROP TABLE {0}".format(self.function_table_name)) @@ -877,37 +1072,66 @@ def test_smoke_duration_values(self): VALUES (?, ?) """) - nanosecond_smoke_values = [0, -1, 1, 100, 1000, 1000000, 1000000000, - 10000000000000,-9223372036854775807, 9223372036854775807, - int("7FFFFFFFFFFFFFFF", 16), int("-7FFFFFFFFFFFFFFF", 16)] - month_day_smoke_values = [0, -1, 1, 100, 1000, 1000000, 1000000000, - int("7FFFFFFF", 16), int("-7FFFFFFF", 16)] + nanosecond_smoke_values = [ + 0, + -1, + 1, + 100, + 1000, + 1000000, + 1000000000, + 10000000000000, + -9223372036854775807, + 9223372036854775807, + int("7FFFFFFFFFFFFFFF", 16), + int("-7FFFFFFFFFFFFFFF", 16), + ] + month_day_smoke_values = [ + 0, + -1, + 1, + 100, + 1000, + 1000000, + 1000000000, + int("7FFFFFFF", 16), + int("-7FFFFFFF", 16), + ] for nanosecond_value in nanosecond_smoke_values: for month_day_value in month_day_smoke_values: - # Must have the same sign if (month_day_value <= 0) != (nanosecond_value <= 0): continue - self.session.execute(prepared, (1, Duration(month_day_value, month_day_value, nanosecond_value))) + self.session.execute( + prepared, + (1, Duration(month_day_value, month_day_value, nanosecond_value)), + ) results = self.session.execute("SELECT * FROM duration_smoke") v = results.one()[1] - assert Duration(month_day_value, month_day_value, nanosecond_value) == v, "Error encoding value {0},{0},{1}".format(month_day_value, nanosecond_value) + assert ( + Duration(month_day_value, month_day_value, nanosecond_value) == v + ), "Error encoding value {0},{0},{1}".format( + month_day_value, nanosecond_value + ) with pytest.raises(ValueError): - self.session.execute(prepared, - (1, Duration(0, 0, int("8FFFFFFFFFFFFFF0", 16)))) + self.session.execute( + prepared, (1, Duration(0, 0, int("8FFFFFFFFFFFFFF0", 16))) + ) with pytest.raises(ValueError): - self.session.execute(prepared, - (1, Duration(0, int("8FFFFFFFFFFFFFF0", 16), 0))) + self.session.execute( + prepared, (1, Duration(0, int("8FFFFFFFFFFFFFF0", 16), 0)) + ) with pytest.raises(ValueError): - self.session.execute(prepared, - (1, Duration(int("8FFFFFFFFFFFFFF0", 16), 0, 0))) + self.session.execute( + prepared, (1, Duration(int("8FFFFFFFFFFFFFF0", 16), 0, 0)) + ) -class TypeTestsProtocol(BasicSharedKeyspaceUnitTestCase): +class TypeTestsProtocol(BasicSharedKeyspaceUnitTestCase): @greaterthancass21 @lessthancass30 def test_nested_types_with_protocol_version(self): @@ -921,26 +1145,30 @@ def test_nested_types_with_protocol_version(self): @test_category data_types serialization """ - ddl = '''CREATE TABLE {0}.t ( + ddl = """CREATE TABLE {0}.t ( k int PRIMARY KEY, - v list>>)'''.format(self.keyspace_name) + v list>>)""".format(self.keyspace_name) self.session.execute(ddl) - ddl = '''CREATE TABLE {0}.u ( + ddl = """CREATE TABLE {0}.u ( k int PRIMARY KEY, - v set>>)'''.format(self.keyspace_name) + v set>>)""".format(self.keyspace_name) self.session.execute(ddl) - ddl = '''CREATE TABLE {0}.v ( + ddl = """CREATE TABLE {0}.v ( k int PRIMARY KEY, v map>, frozen>>, - v1 frozen>)'''.format(self.keyspace_name) + v1 frozen>)""".format(self.keyspace_name) self.session.execute(ddl) - self.session.execute("CREATE TYPE {0}.typ (v0 frozen>>>, v1 frozen>)".format(self.keyspace_name)) + self.session.execute( + "CREATE TYPE {0}.typ (v0 frozen>>>, v1 frozen>)".format( + self.keyspace_name + ) + ) - ddl = '''CREATE TABLE {0}.w ( + ddl = """CREATE TABLE {0}.w ( k int PRIMARY KEY, - v frozen)'''.format(self.keyspace_name) + v frozen)""".format(self.keyspace_name) self.session.execute(ddl) @@ -952,17 +1180,23 @@ def test_nested_types_with_protocol_version(self): def read_inserts_at_level(self, proto_ver): session = TestCluster(protocol_version=proto_ver).connect(self.keyspace_name) try: - results = session.execute('select * from t').one() + results = session.execute("select * from t").one() assert "[SortedSet([1, 2]), SortedSet([3, 5])]" == str(results.v) - results = session.execute('select * from u').one() + results = session.execute("select * from u").one() assert "SortedSet([[1, 2], [3, 5]])" == str(results.v) - results = session.execute('select * from v').one() - assert "{SortedSet([1, 2]): [1, 2, 3], SortedSet([3, 5]): [4, 5, 6]}" == str(results.v) + results = session.execute("select * from v").one() + assert ( + "{SortedSet([1, 2]): [1, 2, 3], SortedSet([3, 5]): [4, 5, 6]}" + == str(results.v) + ) - results = session.execute('select * from w').one() - assert "typ(v0=OrderedMapSerializedKey([(1, [1, 2, 3]), (2, [4, 5, 6])]), v1=[7, 8, 9])" == str(results.v) + results = session.execute("select * from w").one() + assert ( + "typ(v0=OrderedMapSerializedKey([(1, [1, 2, 3]), (2, [4, 5, 6])]), v1=[7, 8, 9])" + == str(results.v) + ) finally: session.cluster.shutdown() @@ -970,31 +1204,37 @@ def read_inserts_at_level(self, proto_ver): def run_inserts_at_version(self, proto_ver): session = TestCluster(protocol_version=proto_ver).connect(self.keyspace_name) try: - p = session.prepare('insert into t (k, v) values (?, ?)') + p = session.prepare("insert into t (k, v) values (?, ?)") session.execute(p, (0, [{1, 2}, {3, 5}])) - p = session.prepare('insert into u (k, v) values (?, ?)') + p = session.prepare("insert into u (k, v) values (?, ?)") session.execute(p, (0, {(1, 2), (3, 5)})) - p = session.prepare('insert into v (k, v, v1) values (?, ?, ?)') - session.execute(p, (0, {(1, 2): [1, 2, 3], (3, 5): [4, 5, 6]}, (123, 'four'))) + p = session.prepare("insert into v (k, v, v1) values (?, ?, ?)") + session.execute( + p, (0, {(1, 2): [1, 2, 3], (3, 5): [4, 5, 6]}, (123, "four")) + ) - p = session.prepare('insert into w (k, v) values (?, ?)') + p = session.prepare("insert into w (k, v) values (?, ?)") session.execute(p, (0, ({1: [1, 2, 3], 2: [4, 5, 6]}, [7, 8, 9]))) finally: session.cluster.shutdown() + @requires_vector_type class TypeTestsVector(BasicSharedKeyspaceUnitTestCase): - def _get_first_j(self, rs): rows = rs.all() assert len(rows) == 1 return rows[0].j def _get_row_simple(self, idx, table_name): - rs = self.session.execute("select j from {0}.{1} where i = {2}".format(self.keyspace_name, table_name, idx)) + rs = self.session.execute( + "select j from {0}.{1} where i = {2}".format( + self.keyspace_name, table_name, idx + ) + ) return self._get_first_j(rs) def _get_row_prepared(self, idx, table_name): @@ -1003,9 +1243,13 @@ def _get_row_prepared(self, idx, table_name): rs = self.session.execute(ps, [idx]) return self._get_first_j(rs) - def _round_trip_test(self, subtype, subtype_fn, test_fn, use_positional_parameters=True): + def _round_trip_test( + self, subtype, subtype_fn, test_fn, use_positional_parameters=True + ): - table_name = subtype.replace("<","A").replace(">", "B").replace(",", "C") + "isH" + table_name = ( + subtype.replace("<", "A").replace(">", "B").replace(",", "C") + "isH" + ) def random_subtype_vector(): return [subtype_fn() for _ in range(3)] @@ -1016,20 +1260,28 @@ def random_subtype_vector(): self.session.execute(ddl) if use_positional_parameters: - cql = "insert into {0}.{1} (i,j) values (%s,%s)".format(self.keyspace_name, table_name) + cql = "insert into {0}.{1} (i,j) values (%s,%s)".format( + self.keyspace_name, table_name + ) expected1 = random_subtype_vector() - data1 = {1:random_subtype_vector(), 2:expected1, 3:random_subtype_vector()} - for k,v in data1.items(): + data1 = { + 1: random_subtype_vector(), + 2: expected1, + 3: random_subtype_vector(), + } + for k, v in data1.items(): # Attempt a set of inserts using the driver's support for positional params - self.session.execute(cql, (k,v)) + self.session.execute(cql, (k, v)) - cql = "insert into {0}.{1} (i,j) values (?,?)".format(self.keyspace_name, table_name) + cql = "insert into {0}.{1} (i,j) values (?,?)".format( + self.keyspace_name, table_name + ) expected2 = random_subtype_vector() ps = self.session.prepare(cql) - data2 = {4:random_subtype_vector(), 5:expected2, 6:random_subtype_vector()} - for k,v in data2.items(): + data2 = {4: random_subtype_vector(), 5: expected2, 6: random_subtype_vector()} + for k, v in data2.items(): # Add some additional rows via prepared statements - self.session.execute(ps, [k,v]) + self.session.execute(ps, [k, v]) # Use prepared queries to gather data from the rows we added via simple queries and vice versa if use_positional_parameters: @@ -1042,36 +1294,52 @@ def random_subtype_vector(): test_fn(observed2[idx], expected2[idx]) def test_round_trip_integers(self): - self._round_trip_test("int", partial(random.randint, 0, 2 ** 31), assertEqual) - self._round_trip_test("bigint", partial(random.randint, 0, 2 ** 63), assertEqual) - self._round_trip_test("smallint", partial(random.randint, 0, 2 ** 15), assertEqual) - self._round_trip_test("tinyint", partial(random.randint, 0, (2 ** 7) - 1), assertEqual) - self._round_trip_test("varint", partial(random.randint, 0, 2 ** 63), assertEqual) + self._round_trip_test("int", partial(random.randint, 0, 2**31), assertEqual) + self._round_trip_test("bigint", partial(random.randint, 0, 2**63), assertEqual) + self._round_trip_test( + "smallint", partial(random.randint, 0, 2**15), assertEqual + ) + self._round_trip_test( + "tinyint", partial(random.randint, 0, (2**7) - 1), assertEqual + ) + self._round_trip_test("varint", partial(random.randint, 0, 2**63), assertEqual) def test_round_trip_floating_point(self): _almost_equal_test_fn = partial(pytest.approx, abs=1e-5) + def _random_decimal(): return Decimal(random.uniform(0.0, 100.0)) # Max value here isn't really connected to max value for floating point nums in IEEE 754... it's used here # mainly as a convenient benchmark - self._round_trip_test("float", partial(random.uniform, 0.0, 100.0), _almost_equal_test_fn) - self._round_trip_test("double", partial(random.uniform, 0.0, 100.0), _almost_equal_test_fn) + self._round_trip_test( + "float", partial(random.uniform, 0.0, 100.0), _almost_equal_test_fn + ) + self._round_trip_test( + "double", partial(random.uniform, 0.0, 100.0), _almost_equal_test_fn + ) self._round_trip_test("decimal", _random_decimal, _almost_equal_test_fn) def test_round_trip_text(self): def _random_string(): - return ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(24)) + return "".join( + random.choice(string.ascii_uppercase + string.digits) for _ in range(24) + ) self._round_trip_test("ascii", _random_string, assertEqual) self._round_trip_test("text", _random_string, assertEqual) def test_round_trip_date_and_time(self): _almost_equal_test_fn = partial(pytest.approx, abs=timedelta(seconds=1)) + def _random_datetime(): - return datetime.today() - timedelta(hours=random.randint(0,18), days=random.randint(1,1000)) + return datetime.today() - timedelta( + hours=random.randint(0, 18), days=random.randint(1, 1000) + ) + def _random_date(): return _random_datetime().date() + def _random_time(): return _random_datetime().time() @@ -1085,11 +1353,16 @@ def test_round_trip_uuid(self): def test_round_trip_miscellany(self): def _random_bytes(): - return random.getrandbits(32).to_bytes(4,'big') + return random.getrandbits(32).to_bytes(4, "big") + def _random_boolean(): return random.choice([True, False]) + def _random_duration(): - return Duration(random.randint(0,11), random.randint(0,11), random.randint(0,10000)) + return Duration( + random.randint(0, 11), random.randint(0, 11), random.randint(0, 10000) + ) + def _random_inet(): return socket.inet_ntoa(_random_bytes()) @@ -1100,11 +1373,13 @@ def _random_inet(): def test_round_trip_collections(self): def _random_seq(): - return [random.randint(0,100000) for _ in range(8)] + return [random.randint(0, 100000) for _ in range(8)] + def _random_set(): return set(_random_seq()) + def _random_map(): - return {k:v for (k,v) in zip(_random_seq(), _random_seq())} + return {k: v for (k, v) in zip(_random_seq(), _random_seq())} # Goal here is to test collections of both fixed and variable size subtypes self._round_trip_test("list", _random_seq, assertEqual) @@ -1118,44 +1393,76 @@ def _random_map(): def test_round_trip_vector_of_vectors(self): def _random_vector(): - return [random.randint(0,100000) for _ in range(2)] + return [random.randint(0, 100000) for _ in range(2)] self._round_trip_test("vector", _random_vector, assertEqual) self._round_trip_test("vector", _random_vector, assertEqual) def test_round_trip_tuples(self): def _random_tuple(): - return (random.randint(0,100000),random.randint(0,100000)) + return (random.randint(0, 100000), random.randint(0, 100000)) # Unfortunately we can't use positional parameters when inserting tuples because the driver will try to encode # them as lists before sending them to the server... and that confuses the parsing logic. - self._round_trip_test("tuple", _random_tuple, assertEqual, use_positional_parameters=False) - self._round_trip_test("tuple", _random_tuple, assertEqual, use_positional_parameters=False) - self._round_trip_test("tuple", _random_tuple, assertEqual, use_positional_parameters=False) - self._round_trip_test("tuple", _random_tuple, assertEqual, use_positional_parameters=False) + self._round_trip_test( + "tuple", + _random_tuple, + assertEqual, + use_positional_parameters=False, + ) + self._round_trip_test( + "tuple", + _random_tuple, + assertEqual, + use_positional_parameters=False, + ) + self._round_trip_test( + "tuple", + _random_tuple, + assertEqual, + use_positional_parameters=False, + ) + self._round_trip_test( + "tuple", + _random_tuple, + assertEqual, + use_positional_parameters=False, + ) def test_round_trip_udts(self): def _udt_equal_test_fn(udt1, udt2): assert udt1.a == udt2.a assert udt1.b == udt2.b - self.session.execute("create type {}.fixed_type (a int, b int)".format(self.keyspace_name)) - self.session.execute("create type {}.mixed_type_one (a int, b varint)".format(self.keyspace_name)) - self.session.execute("create type {}.mixed_type_two (a varint, b int)".format(self.keyspace_name)) - self.session.execute("create type {}.var_type (a varint, b varint)".format(self.keyspace_name)) + self.session.execute( + "create type {}.fixed_type (a int, b int)".format(self.keyspace_name) + ) + self.session.execute( + "create type {}.mixed_type_one (a int, b varint)".format(self.keyspace_name) + ) + self.session.execute( + "create type {}.mixed_type_two (a varint, b int)".format(self.keyspace_name) + ) + self.session.execute( + "create type {}.var_type (a varint, b varint)".format(self.keyspace_name) + ) class GeneralUDT: def __init__(self, a, b): self.a = a self.b = b - self.cluster.register_user_type(self.keyspace_name,'fixed_type', GeneralUDT) - self.cluster.register_user_type(self.keyspace_name,'mixed_type_one', GeneralUDT) - self.cluster.register_user_type(self.keyspace_name,'mixed_type_two', GeneralUDT) - self.cluster.register_user_type(self.keyspace_name,'var_type', GeneralUDT) + self.cluster.register_user_type(self.keyspace_name, "fixed_type", GeneralUDT) + self.cluster.register_user_type( + self.keyspace_name, "mixed_type_one", GeneralUDT + ) + self.cluster.register_user_type( + self.keyspace_name, "mixed_type_two", GeneralUDT + ) + self.cluster.register_user_type(self.keyspace_name, "var_type", GeneralUDT) def _random_udt(): - return GeneralUDT(random.randint(0,100000),random.randint(0,100000)) + return GeneralUDT(random.randint(0, 100000), random.randint(0, 100000)) self._round_trip_test("fixed_type", _random_udt, _udt_equal_test_fn) self._round_trip_test("mixed_type_one", _random_udt, _udt_equal_test_fn) diff --git a/tests/unit/test_marshalling.py b/tests/unit/test_marshalling.py index e4b415ac69..6a4aef2df9 100644 --- a/tests/unit/test_marshalling.py +++ b/tests/unit/test_marshalling.py @@ -66,9 +66,9 @@ (b'', 'InetAddressType', None), (b'A46\xa9', 'InetAddressType', '65.52.54.169'), (b'*\x00\x13(\xe1\x02\xcc\xc0\x00\x00\x00\x00\x00\x00\x01"', 'InetAddressType', '2a00:1328:e102:ccc0::122'), - (b'\xe3\x81\xbe\xe3\x81\x97\xe3\x81\xa6', 'UTF8Type', u'\u307e\u3057\u3066'), - (b'\xe3\x81\xbe\xe3\x81\x97\xe3\x81\xa6' * 1000, 'UTF8Type', u'\u307e\u3057\u3066' * 1000), - (b'', 'UTF8Type', u''), + (b'\xe3\x81\xbe\xe3\x81\x97\xe3\x81\xa6', 'UTF8Type', '\u307e\u3057\u3066'), + (b'\xe3\x81\xbe\xe3\x81\x97\xe3\x81\xa6' * 1000, 'UTF8Type', '\u307e\u3057\u3066' * 1000), + (b'', 'UTF8Type', ''), (b'\xff' * 16, 'UUIDType', UUID('ffffffff-ffff-ffff-ffff-ffffffffffff')), (b'I\x15~\xfc\xef<\x9d\xe3\x16\x98\xaf\x80\x1f\xb4\x0b*', 'UUIDType', UUID('49157efc-ef3c-9de3-1698-af801fb40b2a')), (b'', 'UUIDType', None), @@ -89,9 +89,9 @@ ) ordered_map_value = OrderedMapSerializedKey(UTF8Type, 3) -ordered_map_value._insert(u'\u307fbob', 199) -ordered_map_value._insert(u'', -1) -ordered_map_value._insert(u'\\', 0) +ordered_map_value._insert('\u307fbob', 199) +ordered_map_value._insert('', -1) +ordered_map_value._insert('\\', 0) # these following entries work for me right now, but they're dependent on # vagaries of internal python ordering for unordered types diff --git a/tests/unit/test_metadata.py b/tests/unit/test_metadata.py index dcbb840447..9689463729 100644 --- a/tests/unit/test_metadata.py +++ b/tests/unit/test_metadata.py @@ -671,27 +671,27 @@ def test_table_name(self): def test_column_name_single_partition(self): tm = TableMetadata('ks', 'table') - cm = ColumnMetadata(tm, self.name, u'int') + cm = ColumnMetadata(tm, self.name, 'int') tm.columns[cm.name] = cm tm.partition_key.append(cm) tm.export_as_string() def test_column_name_single_partition_single_clustering(self): tm = TableMetadata('ks', 'table') - cm = ColumnMetadata(tm, self.name, u'int') + cm = ColumnMetadata(tm, self.name, 'int') tm.columns[cm.name] = cm tm.partition_key.append(cm) - cm = ColumnMetadata(tm, self.name + 'x', u'int') + cm = ColumnMetadata(tm, self.name + 'x', 'int') tm.columns[cm.name] = cm tm.clustering_key.append(cm) tm.export_as_string() def test_column_name_multiple_partition(self): tm = TableMetadata('ks', 'table') - cm = ColumnMetadata(tm, self.name, u'int') + cm = ColumnMetadata(tm, self.name, 'int') tm.columns[cm.name] = cm tm.partition_key.append(cm) - cm = ColumnMetadata(tm, self.name + 'x', u'int') + cm = ColumnMetadata(tm, self.name + 'x', 'int') tm.columns[cm.name] = cm tm.partition_key.append(cm) tm.export_as_string() @@ -707,20 +707,20 @@ def test_index(self): def test_function(self): fm = Function(keyspace=self.name, name=self.name, - argument_types=(u'int', u'int'), - argument_names=(u'x', u'y'), - return_type=u'int', language=u'language', + argument_types=('int', 'int'), + argument_names=('x', 'y'), + return_type='int', language='language', body=self.name, called_on_null_input=False, deterministic=True, - monotonic=False, monotonic_on=(u'x',)) + monotonic=False, monotonic_on=('x',)) fm.export_as_string() def test_aggregate(self): - am = Aggregate(self.name, self.name, (u'text',), self.name, u'text', self.name, self.name, u'text', True) + am = Aggregate(self.name, self.name, ('text',), self.name, 'text', self.name, self.name, 'text', True) am.export_as_string() def test_user_type(self): - um = UserType(self.name, self.name, [self.name, self.name], [u'int', u'text']) + um = UserType(self.name, self.name, [self.name, self.name], ['int', 'text']) um.export_as_string() @@ -729,10 +729,10 @@ class FunctionToCQLTests(unittest.TestCase): base_vars = { 'keyspace': 'ks_name', 'name': 'function_name', - 'argument_types': (u'int', u'int'), - 'argument_names': (u'x', u'y'), - 'return_type': u'int', - 'language': u'language', + 'argument_types': ('int', 'int'), + 'argument_names': ('x', 'y'), + 'return_type': 'int', + 'language': 'language', 'body': 'body', 'called_on_null_input': False, 'deterministic': True, @@ -785,10 +785,10 @@ class AggregateToCQLTests(unittest.TestCase): base_vars = { 'keyspace': 'ks_name', 'name': 'function_name', - 'argument_types': (u'int', u'int'), + 'argument_types': ('int', 'int'), 'state_func': 'funcname', - 'state_type': u'int', - 'return_type': u'int', + 'state_type': 'int', + 'return_type': 'int', 'final_func': None, 'initial_condition': '0', 'deterministic': True diff --git a/tests/unit/test_orderedmap.py b/tests/unit/test_orderedmap.py index 156bbd5f30..c93aa6501c 100644 --- a/tests/unit/test_orderedmap.py +++ b/tests/unit/test_orderedmap.py @@ -19,27 +19,28 @@ from tests.util import assertListEqual import pytest + class OrderedMapTest(unittest.TestCase): def test_init(self): - a = OrderedMap(zip(['one', 'three', 'two'], [1, 3, 2])) - b = OrderedMap([('one', 1), ('three', 3), ('two', 2)]) + a = OrderedMap(zip(["one", "three", "two"], [1, 3, 2])) + b = OrderedMap([("one", 1), ("three", 3), ("two", 2)]) c = OrderedMap(a) - builtin = {'one': 1, 'two': 2, 'three': 3} + builtin = {"one": 1, "two": 2, "three": 3} assert a == b assert a == c assert a == builtin assert OrderedMap([(1, 1), (1, 2)]) == {1: 2} - d = OrderedMap({'': 3}, key1='v1', key2='v2') - assert d[''] == 3 - assert d['key1'] == 'v1' - assert d['key2'] == 'v2' + d = OrderedMap({"": 3}, key1="v1", key2="v2") + assert d[""] == 3 + assert d["key1"] == "v1" + assert d["key2"] == "v2" with pytest.raises(TypeError): - OrderedMap('too', 'many', 'args') + OrderedMap("too", "many", "args") def test_contains(self): - keys = ['first', 'middle', 'last'] + keys = ["first", "middle", "last"] om = OrderedMap() @@ -49,45 +50,45 @@ def test_contains(self): assert k in om assert not k not in om - assert 'notthere' not in om - assert not 'notthere' in om + assert "notthere" not in om + assert not "notthere" in om def test_keys(self): - keys = ['first', 'middle', 'last'] + keys = ["first", "middle", "last"] om = OrderedMap(zip(keys, range(len(keys)))) assertListEqual(list(om.keys()), keys) def test_values(self): - keys = ['first', 'middle', 'last'] + keys = ["first", "middle", "last"] values = list(range(len(keys))) om = OrderedMap(zip(keys, values)) assertListEqual(list(om.values()), values) def test_items(self): - keys = ['first', 'middle', 'last'] + keys = ["first", "middle", "last"] items = list(zip(keys, range(len(keys)))) om = OrderedMap(items) assertListEqual(list(om.items()), items) def test_get(self): - keys = ['first', 'middle', 'last'] + keys = ["first", "middle", "last"] om = OrderedMap(zip(keys, range(len(keys)))) for v, k in enumerate(keys): assert om.get(k) == v - assert om.get('notthere', 'default') == 'default' - assert om.get('notthere') is None + assert om.get("notthere", "default") == "default" + assert om.get("notthere") is None def test_equal(self): - d1 = {'one': 1} - d12 = {'one': 1, 'two': 2} - om1 = OrderedMap({'one': 1}) - om12 = OrderedMap([('one', 1), ('two', 2)]) - om21 = OrderedMap([('two', 2), ('one', 1)]) + d1 = {"one": 1} + d12 = {"one": 1, "two": 2} + om1 = OrderedMap({"one": 1}) + om12 = OrderedMap([("one", 1), ("two", 2)]) + om21 = OrderedMap([("two", 2), ("one", 1)]) assert om1 == d1 assert om12 == d12 @@ -99,20 +100,20 @@ def test_equal(self): assert om12 != d1 assert om1 != EMPTY - assert not OrderedMap([('three', 3), ('four', 4)]) == d12 + assert not OrderedMap([("three", 3), ("four", 4)]) == d12 def test_getitem(self): - keys = ['first', 'middle', 'last'] + keys = ["first", "middle", "last"] om = OrderedMap(zip(keys, range(len(keys)))) for v, k in enumerate(keys): assert om[k] == v with pytest.raises(KeyError): - om['notthere'] + om["notthere"] def test_iter(self): - keys = ['first', 'middle', 'last'] + keys = ["first", "middle", "last"] values = list(range(len(keys))) items = list(zip(keys, values)) om = OrderedMap(items) @@ -131,17 +132,24 @@ def test_len(self): assert len(OrderedMap([(1, 1)])) == 1 def test_mutable_keys(self): - d = {'1': 1} + d = {"1": 1} s = set([1, 2, 3]) - om = OrderedMap([(d, 'dict'), (s, 'set')]) + om = OrderedMap([(d, "dict"), (s, "set")]) def test_strings(self): # changes in 3.x - d = {'map': 'inner'} + d = {"map": "inner"} s = set([1, 2, 3]) - assert repr(OrderedMap([('two', 2), ('one', 1), (d, 'value'), (s, 'another')])) == "OrderedMap([('two', 2), ('one', 1), (%r, 'value'), (%r, 'another')])" % (d, s) + assert repr( + OrderedMap([("two", 2), ("one", 1), (d, "value"), (s, "another")]) + ) == "OrderedMap([('two', 2), ('one', 1), (%r, 'value'), (%r, 'another')])" % ( + d, + s, + ) - assert str(OrderedMap([('two', 2), ('one', 1), (d, 'value'), (s, 'another')])) == "{'two': 2, 'one': 1, %r: 'value', %r: 'another'}" % (d, s) + assert str( + OrderedMap([("two", 2), ("one", 1), (d, "value"), (s, "another")]) + ) == "{'two': 2, 'one': 1, %r: 'value', %r: 'another'}" % (d, s) def test_popitem(self): item = (1, 2) @@ -171,16 +179,20 @@ def test_init(self): assert om == {} def test_normalized_lookup(self): - key_type = lookup_casstype('MapType(UTF8Type, Int32Type)') + key_type = lookup_casstype("MapType(UTF8Type, Int32Type)") protocol_version = 3 om = OrderedMapSerializedKey(key_type, protocol_version) - key_ascii = {'one': 1} - key_unicode = {u'two': 2} - om._insert_unchecked(key_ascii, key_type.serialize(key_ascii, protocol_version), object()) - om._insert_unchecked(key_unicode, key_type.serialize(key_unicode, protocol_version), object()) + key_one = {"one": 1} + key_two = {"two": 2} + om._insert_unchecked( + key_one, key_type.serialize(key_one, protocol_version), object() + ) + om._insert_unchecked( + key_two, key_type.serialize(key_two, protocol_version), object() + ) # type lookup is normalized by key_type # PYTHON-231 - assert om[{'one': 1}] is om[{u'one': 1}] - assert om[{'two': 2}] is om[{u'two': 2}] - assert om[{'one': 1}] is not om[{'two': 2}] + assert om[{"one": 1}] is om[{"one": 1}] + assert om[{"two": 2}] is om[{"two": 2}] + assert om[{"one": 1}] is not om[{"two": 2}] diff --git a/tests/unit/test_types.py b/tests/unit/test_types.py index 7a8c584f75..39bd750ca5 100644 --- a/tests/unit/test_types.py +++ b/tests/unit/test_types.py @@ -290,14 +290,14 @@ def test_collection_null_support(self): def test_write_read_string(self): with tempfile.TemporaryFile() as f: - value = u'test' + value = 'test' write_string(f, value) f.seek(0) assert read_string(f) == value def test_write_read_longstring(self): with tempfile.TemporaryFile() as f: - value = u'test' + value = 'test' write_longstring(f, value) f.seek(0) assert read_longstring(f) == value @@ -323,7 +323,7 @@ def test_write_read_inet(self): assert read_inet(f) == value def test_cql_quote(self): - assert cql_quote(u'test') == "'test'" + assert cql_quote('test') == "'test'" assert cql_quote('test') == "'test'" assert cql_quote(0) == '0' From 59b6813ff458089eca30e192d7274db233680480 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Wed, 4 Mar 2026 19:24:55 +0200 Subject: [PATCH 18/18] Update stale Python 2 comments and docstrings Several comments and docstrings still referenced Python 2 concepts that no longer apply now that the driver requires Python 3.9+: - encoder.py: Updated cql_encode_unicode() docstring to note it is unused since Python 2 removal (str is always unicode on Python 3). Also fixed the method body: it was calling val.encode('utf-8') which on Python 3 converts str to bytes, producing wrong output. Now it passes val directly to cql_quote. - metadata.py: Changed 'will always be a unicode' to 'will always be a str' (line 2155). Updated unhexlify comment to say 'str input' instead of 'unicode input' and fixed typo 'everythin' (line 2350). - graphson.py: Removed '(PY2)'/'(PY3)' qualifiers from the type mapping table. Updated 'long' to 'int' for varint, 'str (unicode)' to 'str' for inet, removed 'buffer (PY2)' from blob entries. - util.py: Updated comment on _positional_rename_invalid_identifiers to remove stale 'Python 2.6' reference. - asyncorereactor.py: Removed stale 'TODO: Remove when Python 2 support is removed' since Python 2 support has been removed. The guard itself is still needed for interpreter shutdown scenarios. All 608 unit tests pass. --- cassandra/datastax/graph/graphson.py | 558 ++++++++++++++++----------- cassandra/encoder.py | 85 ++-- cassandra/io/asyncorereactor.py | 6 +- cassandra/metadata.py | 4 +- cassandra/util.py | 3 +- 5 files changed, 391 insertions(+), 265 deletions(-) diff --git a/cassandra/datastax/graph/graphson.py b/cassandra/datastax/graph/graphson.py index 335c7f7825..7b0553e66c 100644 --- a/cassandra/datastax/graph/graphson.py +++ b/cassandra/datastax/graph/graphson.py @@ -31,15 +31,45 @@ from cassandra.util import Polygon, Point, LineString, Duration from cassandra.datastax.graph.types import Vertex, VertexProperty, Edge, Path, T -__all__ = ['GraphSON1Serializer', 'GraphSON1Deserializer', 'GraphSON1TypeDeserializer', - 'GraphSON2Serializer', 'GraphSON2Deserializer', 'GraphSON2Reader', - 'GraphSON3Serializer', 'GraphSON3Deserializer', 'GraphSON3Reader', - 'to_bigint', 'to_int', 'to_double', 'to_float', 'to_smallint', - 'BooleanTypeIO', 'Int16TypeIO', 'Int32TypeIO', 'DoubleTypeIO', - 'FloatTypeIO', 'UUIDTypeIO', 'BigDecimalTypeIO', 'DurationTypeIO', 'InetTypeIO', - 'InstantTypeIO', 'LocalDateTypeIO', 'LocalTimeTypeIO', 'Int64TypeIO', 'BigIntegerTypeIO', - 'LocalDateTypeIO', 'PolygonTypeIO', 'PointTypeIO', 'LineStringTypeIO', 'BlobTypeIO', - 'GraphSON3Serializer', 'GraphSON3Deserializer', 'UserTypeIO', 'TypeWrapperTypeIO'] +__all__ = [ + "GraphSON1Serializer", + "GraphSON1Deserializer", + "GraphSON1TypeDeserializer", + "GraphSON2Serializer", + "GraphSON2Deserializer", + "GraphSON2Reader", + "GraphSON3Serializer", + "GraphSON3Deserializer", + "GraphSON3Reader", + "to_bigint", + "to_int", + "to_double", + "to_float", + "to_smallint", + "BooleanTypeIO", + "Int16TypeIO", + "Int32TypeIO", + "DoubleTypeIO", + "FloatTypeIO", + "UUIDTypeIO", + "BigDecimalTypeIO", + "DurationTypeIO", + "InetTypeIO", + "InstantTypeIO", + "LocalDateTypeIO", + "LocalTimeTypeIO", + "Int64TypeIO", + "BigIntegerTypeIO", + "LocalDateTypeIO", + "PolygonTypeIO", + "PointTypeIO", + "LineStringTypeIO", + "BlobTypeIO", + "GraphSON3Serializer", + "GraphSON3Deserializer", + "UserTypeIO", + "TypeWrapperTypeIO", +] """ Supported types: @@ -56,18 +86,18 @@ bigdecimal | gx:BigDecimal | gx:BigDecimal | Decimal duration | gx:Duration | N/A | timedelta (Classic graph only) DSE Duration | N/A | dse:Duration | Duration (Core graph only) -inet | gx:InetAddress | gx:InetAddress | str (unicode), IPV4Address/IPV6Address (PY3) +inet | gx:InetAddress | gx:InetAddress | str, IPV4Address/IPV6Address timestamp | gx:Instant | gx:Instant | datetime.datetime date | gx:LocalDate | gx:LocalDate | datetime.date time | gx:LocalTime | gx:LocalTime | datetime.time smallint | gx:Int16 | gx:Int16 | int -varint | gx:BigInteger | gx:BigInteger | long +varint | gx:BigInteger | gx:BigInteger | int date | gx:LocalDate | gx:LocalDate | Date polygon | dse:Polygon | dse:Polygon | Polygon point | dse:Point | dse:Point | Point linestring | dse:Linestring | dse:LineString | LineString -blob | dse:Blob | dse:Blob | bytearray, buffer (PY2), memoryview (PY3), bytes (PY3) -blob | gx:ByteBuffer | gx:ByteBuffer | bytearray, buffer (PY2), memoryview (PY3), bytes (PY3) +blob | dse:Blob | dse:Blob | bytearray, memoryview, bytes +blob | gx:ByteBuffer | gx:ByteBuffer | bytearray, memoryview, bytes list | N/A | g:List | list (Core graph only) map | N/A | g:Map | dict (Core graph only) set | N/A | g:Set | set or list (Core graph only) @@ -76,8 +106,8 @@ udt | N/A | dse:UDT | class or namedtuple (Core graph only) """ -MAX_INT32 = 2 ** 32 - 1 -MIN_INT32 = -2 ** 31 +MAX_INT32 = 2**32 - 1 +MIN_INT32 = -(2**31) log = logging.getLogger(__name__) @@ -93,13 +123,13 @@ def graphson_type(cls): class GraphSONTypeIO(object, metaclass=_GraphSONTypeType): """Represent a serializable GraphSON type""" - prefix = 'g' + prefix = "g" graphson_base_type = None cql_type = None @classmethod def definition(cls, value, writer=None): - return {'cqlType': cls.cql_type} + return {"cqlType": cls.cql_type} @classmethod def serialize(cls, value, writer=None): @@ -115,12 +145,12 @@ def get_specialized_serializer(cls, value): class TextTypeIO(GraphSONTypeIO): - cql_type = 'text' + cql_type = "text" class BooleanTypeIO(GraphSONTypeIO): graphson_base_type = None - cql_type = 'boolean' + cql_type = "boolean" @classmethod def serialize(cls, value, writer=None): @@ -128,7 +158,6 @@ def serialize(cls, value, writer=None): class IntegerTypeIO(GraphSONTypeIO): - @classmethod def serialize(cls, value, writer=None): return value @@ -142,19 +171,19 @@ def get_specialized_serializer(cls, value): class Int16TypeIO(IntegerTypeIO): - prefix = 'gx' - graphson_base_type = 'Int16' - cql_type = 'smallint' + prefix = "gx" + graphson_base_type = "Int16" + cql_type = "smallint" class Int32TypeIO(IntegerTypeIO): - graphson_base_type = 'Int32' - cql_type = 'int' + graphson_base_type = "Int32" + cql_type = "int" class Int64TypeIO(IntegerTypeIO): - graphson_base_type = 'Int64' - cql_type = 'bigint' + graphson_base_type = "Int64" + cql_type = "bigint" @classmethod def deserialize(cls, value, reader=None): @@ -162,8 +191,8 @@ def deserialize(cls, value, reader=None): class FloatTypeIO(GraphSONTypeIO): - graphson_base_type = 'Float' - cql_type = 'float' + graphson_base_type = "Float" + cql_type = "float" @classmethod def serialize(cls, value, writer=None): @@ -175,21 +204,21 @@ def deserialize(cls, value, reader=None): class DoubleTypeIO(FloatTypeIO): - graphson_base_type = 'Double' - cql_type = 'double' + graphson_base_type = "Double" + cql_type = "double" class BigIntegerTypeIO(IntegerTypeIO): - prefix = 'gx' - graphson_base_type = 'BigInteger' + prefix = "gx" + graphson_base_type = "BigInteger" class LocalDateTypeIO(GraphSONTypeIO): - FORMAT = '%Y-%m-%d' + FORMAT = "%Y-%m-%d" - prefix = 'gx' - graphson_base_type = 'LocalDate' - cql_type = 'date' + prefix = "gx" + graphson_base_type = "LocalDate" + cql_type = "date" @classmethod def serialize(cls, value, writer=None): @@ -205,14 +234,16 @@ def deserialize(cls, value, reader=None): class InstantTypeIO(GraphSONTypeIO): - prefix = 'gx' - graphson_base_type = 'Instant' - cql_type = 'timestamp' + prefix = "gx" + graphson_base_type = "Instant" + cql_type = "timestamp" @classmethod def serialize(cls, value, writer=None): if isinstance(value, datetime.datetime): - value = datetime.datetime(*value.utctimetuple()[:6]).replace(microsecond=value.microsecond) + value = datetime.datetime(*value.utctimetuple()[:6]).replace( + microsecond=value.microsecond + ) else: value = datetime.datetime.combine(value, datetime.datetime.min.time()) @@ -221,22 +252,18 @@ def serialize(cls, value, writer=None): @classmethod def deserialize(cls, value, reader=None): try: - d = datetime.datetime.strptime(value, '%Y-%m-%dT%H:%M:%S.%fZ') + d = datetime.datetime.strptime(value, "%Y-%m-%dT%H:%M:%S.%fZ") except ValueError: - d = datetime.datetime.strptime(value, '%Y-%m-%dT%H:%M:%SZ') + d = datetime.datetime.strptime(value, "%Y-%m-%dT%H:%M:%SZ") return d class LocalTimeTypeIO(GraphSONTypeIO): - FORMATS = [ - '%H:%M', - '%H:%M:%S', - '%H:%M:%S.%f' - ] + FORMATS = ["%H:%M", "%H:%M:%S", "%H:%M:%S.%f"] - prefix = 'gx' - graphson_base_type = 'LocalTime' - cql_type = 'time' + prefix = "gx" + graphson_base_type = "LocalTime" + cql_type = "time" @classmethod def serialize(cls, value, writer=None): @@ -253,20 +280,20 @@ def deserialize(cls, value, reader=None): continue if dt is None: - raise ValueError('Unable to decode LocalTime: {0}'.format(value)) + raise ValueError("Unable to decode LocalTime: {0}".format(value)) return dt.time() class BlobTypeIO(GraphSONTypeIO): - prefix = 'dse' - graphson_base_type = 'Blob' - cql_type = 'blob' + prefix = "dse" + graphson_base_type = "Blob" + cql_type = "blob" @classmethod def serialize(cls, value, writer=None): value = base64.b64encode(value) - value = value.decode('utf-8') + value = value.decode("utf-8") return value @classmethod @@ -275,13 +302,13 @@ def deserialize(cls, value, reader=None): class ByteBufferTypeIO(BlobTypeIO): - prefix = 'gx' - graphson_base_type = 'ByteBuffer' + prefix = "gx" + graphson_base_type = "ByteBuffer" class UUIDTypeIO(GraphSONTypeIO): - graphson_base_type = 'UUID' - cql_type = 'uuid' + graphson_base_type = "UUID" + cql_type = "uuid" @classmethod def deserialize(cls, value, reader=None): @@ -289,9 +316,9 @@ def deserialize(cls, value, reader=None): class BigDecimalTypeIO(GraphSONTypeIO): - prefix = 'gx' - graphson_base_type = 'BigDecimal' - cql_type = 'bigdecimal' + prefix = "gx" + graphson_base_type = "BigDecimal" + cql_type = "bigdecimal" @classmethod def deserialize(cls, value, reader=None): @@ -299,16 +326,19 @@ def deserialize(cls, value, reader=None): class DurationTypeIO(GraphSONTypeIO): - prefix = 'gx' - graphson_base_type = 'Duration' - cql_type = 'duration' + prefix = "gx" + graphson_base_type = "Duration" + cql_type = "duration" - _duration_regex = re.compile(r""" + _duration_regex = re.compile( + r""" ^P((?P\d+)D)? T((?P\d+)H)? ((?P\d+)M)? ((?P[0-9.]+)S)?$ - """, re.VERBOSE) + """, + re.VERBOSE, + ) _duration_format = "P{days}DT{hours}H{minutes}M{seconds}S" _seconds_in_minute = 60 @@ -324,48 +354,51 @@ def serialize(cls, value, writer=None): total_seconds += value.microseconds / 1e6 return cls._duration_format.format( - days=int(days), hours=int(hours), minutes=int(minutes), seconds=total_seconds + days=int(days), + hours=int(hours), + minutes=int(minutes), + seconds=total_seconds, ) @classmethod def deserialize(cls, value, reader=None): duration = cls._duration_regex.match(value) if duration is None: - raise ValueError('Invalid duration: {0}'.format(value)) + raise ValueError("Invalid duration: {0}".format(value)) - duration = {k: float(v) if v is not None else 0 - for k, v in duration.groupdict().items()} - return datetime.timedelta(days=duration['days'], hours=duration['hours'], - minutes=duration['minutes'], seconds=duration['seconds']) + duration = { + k: float(v) if v is not None else 0 for k, v in duration.groupdict().items() + } + return datetime.timedelta( + days=duration["days"], + hours=duration["hours"], + minutes=duration["minutes"], + seconds=duration["seconds"], + ) class DseDurationTypeIO(GraphSONTypeIO): - prefix = 'dse' - graphson_base_type = 'Duration' - cql_type = 'duration' + prefix = "dse" + graphson_base_type = "Duration" + cql_type = "duration" @classmethod def serialize(cls, value, writer=None): - return { - 'months': value.months, - 'days': value.days, - 'nanos': value.nanoseconds - } + return {"months": value.months, "days": value.days, "nanos": value.nanoseconds} @classmethod def deserialize(cls, value, reader=None): return Duration( - reader.deserialize(value['months']), - reader.deserialize(value['days']), - reader.deserialize(value['nanos']) + reader.deserialize(value["months"]), + reader.deserialize(value["days"]), + reader.deserialize(value["nanos"]), ) class TypeWrapperTypeIO(GraphSONTypeIO): - @classmethod def definition(cls, value, writer=None): - return {'cqlType': value.type_io.cql_type} + return {"cqlType": value.type_io.cql_type} @classmethod def serialize(cls, value, writer=None): @@ -377,8 +410,8 @@ def deserialize(cls, value, reader=None): class PointTypeIO(GraphSONTypeIO): - prefix = 'dse' - graphson_base_type = 'Point' + prefix = "dse" + graphson_base_type = "Point" cql_type = "org.apache.cassandra.db.marshal.PointType" @classmethod @@ -387,8 +420,8 @@ def deserialize(cls, value, reader=None): class LineStringTypeIO(GraphSONTypeIO): - prefix = 'dse' - graphson_base_type = 'LineString' + prefix = "dse" + graphson_base_type = "LineString" cql_type = "org.apache.cassandra.db.marshal.LineStringType" @classmethod @@ -397,8 +430,8 @@ def deserialize(cls, value, reader=None): class PolygonTypeIO(GraphSONTypeIO): - prefix = 'dse' - graphson_base_type = 'Polygon' + prefix = "dse" + graphson_base_type = "Polygon" cql_type = "org.apache.cassandra.db.marshal.PolygonType" @classmethod @@ -407,62 +440,70 @@ def deserialize(cls, value, reader=None): class InetTypeIO(GraphSONTypeIO): - prefix = 'gx' - graphson_base_type = 'InetAddress' - cql_type = 'inet' + prefix = "gx" + graphson_base_type = "InetAddress" + cql_type = "inet" class VertexTypeIO(GraphSONTypeIO): - graphson_base_type = 'Vertex' + graphson_base_type = "Vertex" @classmethod def deserialize(cls, value, reader=None): - vertex = Vertex(id=reader.deserialize(value["id"]), - label=value["label"] if "label" in value else "vertex", - type='vertex', - properties={}) + vertex = Vertex( + id=reader.deserialize(value["id"]), + label=value["label"] if "label" in value else "vertex", + type="vertex", + properties={}, + ) # avoid the properties processing in Vertex.__init__ - vertex.properties = reader.deserialize(value.get('properties', {})) + vertex.properties = reader.deserialize(value.get("properties", {})) return vertex class VertexPropertyTypeIO(GraphSONTypeIO): - graphson_base_type = 'VertexProperty' + graphson_base_type = "VertexProperty" @classmethod def deserialize(cls, value, reader=None): - return VertexProperty(label=value['label'], - value=reader.deserialize(value["value"]), - properties=reader.deserialize(value.get('properties', {}))) + return VertexProperty( + label=value["label"], + value=reader.deserialize(value["value"]), + properties=reader.deserialize(value.get("properties", {})), + ) class EdgeTypeIO(GraphSONTypeIO): - graphson_base_type = 'Edge' + graphson_base_type = "Edge" @classmethod def deserialize(cls, value, reader=None): - in_vertex = Vertex(id=reader.deserialize(value["inV"]), - label=value['inVLabel'], - type='vertex', - properties={}) - out_vertex = Vertex(id=reader.deserialize(value["outV"]), - label=value['outVLabel'], - type='vertex', - properties={}) + in_vertex = Vertex( + id=reader.deserialize(value["inV"]), + label=value["inVLabel"], + type="vertex", + properties={}, + ) + out_vertex = Vertex( + id=reader.deserialize(value["outV"]), + label=value["outVLabel"], + type="vertex", + properties={}, + ) return Edge( id=reader.deserialize(value["id"]), label=value["label"] if "label" in value else "vertex", - type='edge', + type="edge", properties=reader.deserialize(value.get("properties", {})), inV=in_vertex, - inVLabel=value['inVLabel'], + inVLabel=value["inVLabel"], outV=out_vertex, - outVLabel=value['outVLabel'] + outVLabel=value["outVLabel"], ) class PropertyTypeIO(GraphSONTypeIO): - graphson_base_type = 'Property' + graphson_base_type = "Property" @classmethod def deserialize(cls, value, reader=None): @@ -470,19 +511,19 @@ def deserialize(cls, value, reader=None): class PathTypeIO(GraphSONTypeIO): - graphson_base_type = 'Path' + graphson_base_type = "Path" @classmethod def deserialize(cls, value, reader=None): - labels = [set(label) for label in reader.deserialize(value['labels'])] - objects = [obj for obj in reader.deserialize(value['objects'])] + labels = [set(label) for label in reader.deserialize(value["labels"])] + objects = [obj for obj in reader.deserialize(value["objects"])] p = Path(labels, []) p.objects = objects # avoid the object processing in Path.__init__ return p class TraversalMetricsTypeIO(GraphSONTypeIO): - graphson_base_type = 'TraversalMetrics' + graphson_base_type = "TraversalMetrics" @classmethod def deserialize(cls, value, reader=None): @@ -490,7 +531,7 @@ def deserialize(cls, value, reader=None): class MetricsTypeIO(GraphSONTypeIO): - graphson_base_type = 'Metrics' + graphson_base_type = "Metrics" @classmethod def deserialize(cls, value, reader=None): @@ -512,17 +553,17 @@ def serialize(cls, value, writer=None): class MapTypeIO(GraphSONTypeIO): """In GraphSON3, dict has its own type""" - graphson_base_type = 'Map' - cql_type = 'map' + graphson_base_type = "Map" + cql_type = "map" @classmethod def definition(cls, value, writer=None): - out = OrderedDict([('cqlType', cls.cql_type)]) - out['definition'] = [] + out = OrderedDict([("cqlType", cls.cql_type)]) + out["definition"] = [] for k, v in value.items(): # we just need the first pair to write the def - out['definition'].append(writer.definition(k)) - out['definition'].append(writer.definition(v)) + out["definition"].append(writer.definition(k)) + out["definition"].append(writer.definition(v)) break return out @@ -540,8 +581,7 @@ def deserialize(cls, value, reader=None): out = {} a, b = itertools.tee(value) for key, val in zip( - itertools.islice(a, 0, None, 2), - itertools.islice(b, 1, None, 2) + itertools.islice(a, 0, None, 2), itertools.islice(b, 1, None, 2) ): out[reader.deserialize(key)] = reader.deserialize(val) return out @@ -550,15 +590,15 @@ def deserialize(cls, value, reader=None): class ListTypeIO(GraphSONTypeIO): """In GraphSON3, list has its own type""" - graphson_base_type = 'List' - cql_type = 'list' + graphson_base_type = "List" + cql_type = "list" @classmethod def definition(cls, value, writer=None): - out = OrderedDict([('cqlType', cls.cql_type)]) - out['definition'] = [] + out = OrderedDict([("cqlType", cls.cql_type)]) + out["definition"] = [] if value: - out['definition'].append(writer.definition(value[0])) + out["definition"].append(writer.definition(value[0])) return out @classmethod @@ -573,16 +613,16 @@ def deserialize(cls, value, reader=None): class SetTypeIO(GraphSONTypeIO): """In GraphSON3, set has its own type""" - graphson_base_type = 'Set' - cql_type = 'set' + graphson_base_type = "Set" + cql_type = "set" @classmethod def definition(cls, value, writer=None): - out = OrderedDict([('cqlType', cls.cql_type)]) - out['definition'] = [] + out = OrderedDict([("cqlType", cls.cql_type)]) + out["definition"] = [] for v in value: # we only take into account the first value for the definition - out['definition'].append(writer.definition(v)) + out["definition"].append(writer.definition(v)) break return out @@ -596,8 +636,10 @@ def deserialize(cls, value, reader=None): s = set(lst) if len(s) != len(lst): - log.warning("Coercing g:Set to list due to numerical values returned by Java. " - "See TINKERPOP-1844 for details.") + log.warning( + "Coercing g:Set to list due to numerical values returned by Java. " + "See TINKERPOP-1844 for details." + ) return lst return s @@ -612,8 +654,7 @@ def deserialize(cls, value, reader=None): a, b = itertools.tee(value) for val, bulk in zip( - itertools.islice(a, 0, None, 2), - itertools.islice(b, 1, None, 2) + itertools.islice(a, 0, None, 2), itertools.islice(b, 1, None, 2) ): val = reader.deserialize(val) bulk = reader.deserialize(bulk) @@ -624,58 +665,64 @@ def deserialize(cls, value, reader=None): class TupleTypeIO(GraphSONTypeIO): - prefix = 'dse' - graphson_base_type = 'Tuple' - cql_type = 'tuple' + prefix = "dse" + graphson_base_type = "Tuple" + cql_type = "tuple" @classmethod def definition(cls, value, writer=None): out = OrderedDict() - out['cqlType'] = cls.cql_type + out["cqlType"] = cls.cql_type serializers = [writer.get_serializer(s) for s in value] - out['definition'] = [s.definition(v, writer) for v, s in zip(value, serializers)] + out["definition"] = [ + s.definition(v, writer) for v, s in zip(value, serializers) + ] return out @classmethod def serialize(cls, value, writer=None): out = cls.definition(value, writer) - out['value'] = [writer.serialize(v, writer) for v in value] + out["value"] = [writer.serialize(v, writer) for v in value] return out @classmethod def deserialize(cls, value, reader=None): - return tuple(reader.deserialize(obj) for obj in value['value']) + return tuple(reader.deserialize(obj) for obj in value["value"]) class UserTypeIO(GraphSONTypeIO): - prefix = 'dse' - graphson_base_type = 'UDT' - cql_type = 'udt' + prefix = "dse" + graphson_base_type = "UDT" + cql_type = "udt" FROZEN_REMOVAL_REGEX = re.compile(r'frozen<"*([^"]+)"*>') @classmethod def cql_types_from_string(cls, typ): # sanitizing: remove frozen references and double quotes... - return cql_types_from_string( - re.sub(cls.FROZEN_REMOVAL_REGEX, r'\1', typ) - ) + return cql_types_from_string(re.sub(cls.FROZEN_REMOVAL_REGEX, r"\1", typ)) @classmethod def get_udt_definition(cls, value, writer): user_type_name = writer.user_types[type(value)] - keyspace = writer.context['graph_name'] - return writer.context['cluster'].metadata.keyspaces[keyspace].user_types[user_type_name] + keyspace = writer.context["graph_name"] + return ( + writer.context["cluster"] + .metadata.keyspaces[keyspace] + .user_types[user_type_name] + ) @classmethod def is_collection(cls, typ): - return typ in ['list', 'tuple', 'map', 'set'] + return typ in ["list", "tuple", "map", "set"] @classmethod def is_udt(cls, typ, writer): - keyspace = writer.context['graph_name'] - if keyspace in writer.context['cluster'].metadata.keyspaces: - return typ in writer.context['cluster'].metadata.keyspaces[keyspace].user_types + keyspace = writer.context["graph_name"] + if keyspace in writer.context["cluster"].metadata.keyspaces: + return ( + typ in writer.context["cluster"].metadata.keyspaces[keyspace].user_types + ) return False @classmethod @@ -684,7 +731,7 @@ def field_definition(cls, types, writer, name=None): Build the udt field definition. This is required when we have a complex udt type. """ index = -1 - out = [OrderedDict() if name is None else OrderedDict([('fieldName', name)])] + out = [OrderedDict() if name is None else OrderedDict([("fieldName", name)])] while types: index += 1 @@ -693,52 +740,72 @@ def field_definition(cls, types, writer, name=None): out.append(OrderedDict()) if cls.is_udt(typ, writer): - keyspace = writer.context['graph_name'] - udt = writer.context['cluster'].metadata.keyspaces[keyspace].user_types[typ] + keyspace = writer.context["graph_name"] + udt = ( + writer.context["cluster"] + .metadata.keyspaces[keyspace] + .user_types[typ] + ) out[index].update(cls.definition(udt, writer)) elif cls.is_collection(typ): - out[index]['cqlType'] = typ + out[index]["cqlType"] = typ definition = cls.field_definition(types, writer) - out[index]['definition'] = definition if isinstance(definition, list) else [definition] + out[index]["definition"] = ( + definition if isinstance(definition, list) else [definition] + ) else: - out[index]['cqlType'] = typ + out[index]["cqlType"] = typ return out if len(out) > 1 else out[0] @classmethod def definition(cls, value, writer=None): - udt = value if isinstance(value, UserType) else cls.get_udt_definition(value, writer) - return OrderedDict([ - ('cqlType', cls.cql_type), - ('keyspace', udt.keyspace), - ('name', udt.name), - ('definition', [ - cls.field_definition(cls.cql_types_from_string(typ), writer, name=name) - for name, typ in zip(udt.field_names, udt.field_types)]) - ]) + udt = ( + value + if isinstance(value, UserType) + else cls.get_udt_definition(value, writer) + ) + return OrderedDict( + [ + ("cqlType", cls.cql_type), + ("keyspace", udt.keyspace), + ("name", udt.name), + ( + "definition", + [ + cls.field_definition( + cls.cql_types_from_string(typ), writer, name=name + ) + for name, typ in zip(udt.field_names, udt.field_types) + ], + ), + ] + ) @classmethod def serialize(cls, value, writer=None): udt = cls.get_udt_definition(value, writer) out = cls.definition(value, writer) - out['value'] = [] + out["value"] = [] for name, typ in zip(udt.field_names, udt.field_types): - out['value'].append(writer.serialize(getattr(value, name), writer)) + out["value"].append(writer.serialize(getattr(value, name), writer)) return out @classmethod def deserialize(cls, value, reader=None): - udt_class = reader.context['cluster']._user_types[value['keyspace']][value['name']] + udt_class = reader.context["cluster"]._user_types[value["keyspace"]][ + value["name"] + ] kwargs = zip( - list(map(lambda v: v['fieldName'], value['definition'])), - [reader.deserialize(v) for v in value['value']] + list(map(lambda v: v["fieldName"], value["definition"])), + [reader.deserialize(v) for v in value["value"]], ) return udt_class(**dict(kwargs)) class TTypeIO(GraphSONTypeIO): - prefix = 'g' - graphson_base_type = 'T' + prefix = "g" + graphson_base_type = "T" @classmethod def deserialize(cls, value, reader=None): @@ -746,7 +813,6 @@ def deserialize(cls, value, reader=None): class _BaseGraphSONSerializer(object): - _serializers = OrderedDict() @classmethod @@ -814,22 +880,24 @@ class GraphSON1Serializer(_BaseGraphSONSerializer): # When we fall back to a superclass's serializer, we iterate over this map. # We want that iteration order to be consistent, so we use an OrderedDict, # not a dict. - _serializers = OrderedDict([ - (str, TextTypeIO), - (bool, BooleanTypeIO), - (bytearray, ByteBufferTypeIO), - (Decimal, BigDecimalTypeIO), - (datetime.date, LocalDateTypeIO), - (datetime.time, LocalTimeTypeIO), - (datetime.timedelta, DurationTypeIO), - (datetime.datetime, InstantTypeIO), - (uuid.UUID, UUIDTypeIO), - (Polygon, PolygonTypeIO), - (Point, PointTypeIO), - (LineString, LineStringTypeIO), - (dict, JsonMapTypeIO), - (float, FloatTypeIO) - ]) + _serializers = OrderedDict( + [ + (str, TextTypeIO), + (bool, BooleanTypeIO), + (bytearray, ByteBufferTypeIO), + (Decimal, BigDecimalTypeIO), + (datetime.date, LocalDateTypeIO), + (datetime.time, LocalTimeTypeIO), + (datetime.timedelta, DurationTypeIO), + (datetime.datetime, InstantTypeIO), + (uuid.UUID, UUIDTypeIO), + (Polygon, PolygonTypeIO), + (Point, PointTypeIO), + (LineString, LineStringTypeIO), + (dict, JsonMapTypeIO), + (float, FloatTypeIO), + ] + ) GraphSON1Serializer.register(ipaddress.IPv4Address, InetTypeIO) @@ -839,7 +907,6 @@ class GraphSON1Serializer(_BaseGraphSONSerializer): class _BaseGraphSONDeserializer(object): - _deserializers = {} @classmethod @@ -855,7 +922,9 @@ def get_deserializer(cls, graphson_type): try: return cls._deserializers[graphson_type] except KeyError: - raise ValueError('Invalid `graphson_type` specified: {}'.format(graphson_type)) + raise ValueError( + "Invalid `graphson_type` specified: {}".format(graphson_type) + ) @classmethod def deserialize(cls, graphson_type, value): @@ -872,14 +941,23 @@ class GraphSON1Deserializer(_BaseGraphSONDeserializer): """ Deserialize graphson1 types to python objects. """ - _TYPES = [UUIDTypeIO, BigDecimalTypeIO, InstantTypeIO, BlobTypeIO, ByteBufferTypeIO, - PointTypeIO, LineStringTypeIO, PolygonTypeIO, LocalDateTypeIO, - LocalTimeTypeIO, DurationTypeIO, InetTypeIO] - _deserializers = { - t.graphson_type: t - for t in _TYPES - } + _TYPES = [ + UUIDTypeIO, + BigDecimalTypeIO, + InstantTypeIO, + BlobTypeIO, + ByteBufferTypeIO, + PointTypeIO, + LineStringTypeIO, + PolygonTypeIO, + LocalDateTypeIO, + LocalTimeTypeIO, + DurationTypeIO, + InetTypeIO, + ] + + _deserializers = {t.graphson_type: t for t in _TYPES} @classmethod def deserialize_date(cls, value): @@ -969,7 +1047,9 @@ def serialize(self, value, writer=None): """ serializer = self.get_serializer(value) if not serializer: - raise ValueError("Unable to find a serializer for value of type: ".format(type(value))) + raise ValueError( + "Unable to find a serializer for value of type: {}".format(type(value)) + ) val = serializer.serialize(value, writer or self) if serializer is TypeWrapperTypeIO: @@ -993,16 +1073,23 @@ def serialize(self, value, writer=None): class GraphSON2Deserializer(_BaseGraphSONDeserializer): - _TYPES = GraphSON1Deserializer._TYPES + [ - Int16TypeIO, Int32TypeIO, Int64TypeIO, DoubleTypeIO, FloatTypeIO, - BigIntegerTypeIO, VertexTypeIO, VertexPropertyTypeIO, EdgeTypeIO, - PathTypeIO, PropertyTypeIO, TraversalMetricsTypeIO, MetricsTypeIO] + Int16TypeIO, + Int32TypeIO, + Int64TypeIO, + DoubleTypeIO, + FloatTypeIO, + BigIntegerTypeIO, + VertexTypeIO, + VertexPropertyTypeIO, + EdgeTypeIO, + PathTypeIO, + PropertyTypeIO, + TraversalMetricsTypeIO, + MetricsTypeIO, + ] - _deserializers = { - t.graphson_type: t - for t in _TYPES - } + _deserializers = {t.graphson_type: t for t in _TYPES} class GraphSON2Reader(object): @@ -1066,7 +1153,6 @@ def _wrap_value(type_io, value): class GraphSON3Serializer(GraphSON2Serializer): - _serializers = GraphSON2Serializer.get_type_definitions() context = None @@ -1084,17 +1170,23 @@ def get_serializer(self, value): """Custom get_serializer to support UDT/Tuple""" serializer = super(GraphSON3Serializer, self).get_serializer(value) - is_namedtuple_udt = serializer is TupleTypeIO and hasattr(value, '_fields') + is_namedtuple_udt = serializer is TupleTypeIO and hasattr(value, "_fields") if not serializer or is_namedtuple_udt: # Check if UDT if self.user_types is None: try: - user_types = self.context['cluster']._user_types[self.context['graph_name']] + user_types = self.context["cluster"]._user_types[ + self.context["graph_name"] + ] self.user_types = dict(map(reversed, user_types.items())) except KeyError: self.user_types = {} - serializer = UserTypeIO if (is_namedtuple_udt or (type(value) in self.user_types)) else serializer + serializer = ( + UserTypeIO + if (is_namedtuple_udt or (type(value) in self.user_types)) + else serializer + ) return serializer @@ -1108,10 +1200,16 @@ def get_serializer(self, value): class GraphSON3Deserializer(GraphSON2Deserializer): - _TYPES = GraphSON2Deserializer._TYPES + [MapTypeIO, ListTypeIO, - SetTypeIO, TupleTypeIO, - UserTypeIO, DseDurationTypeIO, - TTypeIO, BulkSetTypeIO] + _TYPES = GraphSON2Deserializer._TYPES + [ + MapTypeIO, + ListTypeIO, + SetTypeIO, + TupleTypeIO, + UserTypeIO, + DseDurationTypeIO, + TTypeIO, + BulkSetTypeIO, + ] _deserializers = {t.graphson_type: t for t in _TYPES} diff --git a/cassandra/encoder.py b/cassandra/encoder.py index d803c087ba..34045754db 100644 --- a/cassandra/encoder.py +++ b/cassandra/encoder.py @@ -18,6 +18,7 @@ """ import logging + log = logging.getLogger(__name__) from binascii import hexlify @@ -30,8 +31,17 @@ from uuid import UUID import ipaddress -from cassandra.util import (OrderedDict, OrderedMap, OrderedMapSerializedKey, - sortedset, Time, Date, Point, LineString, Polygon) +from cassandra.util import ( + OrderedDict, + OrderedMap, + OrderedMapSerializedKey, + sortedset, + Time, + Date, + Point, + LineString, + Polygon, +) def cql_quote(term): @@ -83,28 +93,36 @@ def __init__(self): ValueSequence: self.cql_encode_sequence, Point: self.cql_encode_str_quoted, LineString: self.cql_encode_str_quoted, - Polygon: self.cql_encode_str_quoted + Polygon: self.cql_encode_str_quoted, } - self.mapping.update({ - memoryview: self.cql_encode_bytes, - bytes: self.cql_encode_bytes, - type(None): self.cql_encode_none, - ipaddress.IPv4Address: self.cql_encode_ipaddress, - ipaddress.IPv6Address: self.cql_encode_ipaddress - }) + self.mapping.update( + { + memoryview: self.cql_encode_bytes, + bytes: self.cql_encode_bytes, + type(None): self.cql_encode_none, + ipaddress.IPv4Address: self.cql_encode_ipaddress, + ipaddress.IPv6Address: self.cql_encode_ipaddress, + } + ) def cql_encode_none(self, val): """ Converts :const:`None` to the string 'NULL'. """ - return 'NULL' + return "NULL" def cql_encode_unicode(self, val): """ - Converts :class:`unicode` objects to UTF-8 encoded strings with quote escaping. + Encodes a string value with quote escaping. + + .. deprecated:: + This method is unused internally since Python 2 support was + removed (``str`` is always unicode on Python 3). It is kept + for backward compatibility with user subclasses of + :class:`Encoder`. """ - return cql_quote(val.encode('utf-8')) + return cql_quote(val) def cql_encode_str(self, val): """ @@ -116,7 +134,7 @@ def cql_encode_str_quoted(self, val): return "'%s'" % val def cql_encode_bytes(self, val): - return (b'0x' + hexlify(val)).decode('utf-8') + return (b"0x" + hexlify(val)).decode("utf-8") def cql_encode_object(self, val): """ @@ -130,9 +148,9 @@ def cql_encode_float(self, val): Encode floats using repr to preserve precision """ if math.isinf(val): - return 'Infinity' if val > 0 else '-Infinity' + return "Infinity" if val > 0 else "-Infinity" elif math.isnan(val): - return 'NaN' + return "NaN" else: return repr(val) @@ -142,14 +160,14 @@ def cql_encode_datetime(self, val): with millisecond precision. """ timestamp = calendar.timegm(val.utctimetuple()) - return str(timestamp * 1000 + getattr(val, 'microsecond', 0) // 1000) + return str(timestamp * 1000 + getattr(val, "microsecond", 0) // 1000) def cql_encode_date(self, val): """ Converts a :class:`datetime.date` object to a string with format ``YYYY-MM-DD``. """ - return "'%s'" % val.strftime('%Y-%m-%d') + return "'%s'" % val.strftime("%Y-%m-%d") def cql_encode_time(self, val): """ @@ -163,15 +181,16 @@ def cql_encode_date_ext(self, val): Encodes a :class:`cassandra.util.Date` object as an integer """ # using the int form in case the Date exceeds datetime.[MIN|MAX]YEAR - return str(val.days_from_epoch + 2 ** 31) + return str(val.days_from_epoch + 2**31) def cql_encode_sequence(self, val): """ Converts a sequence to a string of the form ``(item1, item2, ...)``. This is suitable for ``IN`` value lists. """ - return '(%s)' % ', '.join(self.mapping.get(type(v), self.cql_encode_object)(v) - for v in val) + return "(%s)" % ", ".join( + self.mapping.get(type(v), self.cql_encode_object)(v) for v in val + ) cql_encode_tuple = cql_encode_sequence """ @@ -184,24 +203,32 @@ def cql_encode_map_collection(self, val): Converts a dict into a string of the form ``{key1: val1, key2: val2, ...}``. This is suitable for ``map`` type columns. """ - return '{%s}' % ', '.join('%s: %s' % ( - self.mapping.get(type(k), self.cql_encode_object)(k), - self.mapping.get(type(v), self.cql_encode_object)(v) - ) for k, v in val.items()) + return "{%s}" % ", ".join( + "%s: %s" + % ( + self.mapping.get(type(k), self.cql_encode_object)(k), + self.mapping.get(type(v), self.cql_encode_object)(v), + ) + for k, v in val.items() + ) def cql_encode_list_collection(self, val): """ Converts a sequence to a string of the form ``[item1, item2, ...]``. This is suitable for ``list`` type columns. """ - return '[%s]' % ', '.join(self.mapping.get(type(v), self.cql_encode_object)(v) for v in val) + return "[%s]" % ", ".join( + self.mapping.get(type(v), self.cql_encode_object)(v) for v in val + ) def cql_encode_set_collection(self, val): """ Converts a sequence to a string of the form ``{item1, item2, ...}``. This is suitable for ``set`` type columns. """ - return '{%s}' % ', '.join(self.mapping.get(type(v), self.cql_encode_object)(v) for v in val) + return "{%s}" % ", ".join( + self.mapping.get(type(v), self.cql_encode_object)(v) for v in val + ) def cql_encode_all_types(self, val, as_text_type=False): """ @@ -210,7 +237,7 @@ def cql_encode_all_types(self, val, as_text_type=False): """ encoded = self.mapping.get(type(val), self.cql_encode_object)(val) if as_text_type and not isinstance(encoded, str): - return encoded.decode('utf-8') + return encoded.decode("utf-8") return encoded def cql_encode_ipaddress(self, val): @@ -221,4 +248,4 @@ def cql_encode_ipaddress(self, val): return "'%s'" % val.compressed def cql_encode_decimal(self, val): - return self.cql_encode_float(float(val)) \ No newline at end of file + return self.cql_encode_float(float(val)) diff --git a/cassandra/io/asyncorereactor.py b/cassandra/io/asyncorereactor.py index b43633d352..4dac8b7c3a 100644 --- a/cassandra/io/asyncorereactor.py +++ b/cassandra/io/asyncorereactor.py @@ -280,9 +280,9 @@ def _maybe_log_debug(self, *args, **kwargs): try: log.debug(*args, **kwargs) except Exception: - # TODO: Remove when Python 2 support is removed - # PYTHON-1266. If our logger has disappeared, there's nothing we - # can do, so just log nothing. + # PYTHON-1266. If our logger has disappeared (e.g. during + # interpreter shutdown), there's nothing we can do, so just + # log nothing. pass def add_timer(self, timer): diff --git a/cassandra/metadata.py b/cassandra/metadata.py index 3d4a89a0b5..40765d668b 100644 --- a/cassandra/metadata.py +++ b/cassandra/metadata.py @@ -2152,7 +2152,7 @@ def as_cql_query(self): class_name, ) if options: - # PYTHON-1008: `ret` will always be a unicode + # PYTHON-1008: `ret` will always be a str opts_cql_encoded = _encoder.cql_encode_all_types( options, as_text_type=True ) @@ -2347,7 +2347,7 @@ class BytesToken(Token): @classmethod def from_string(cls, token_string): """`token_string` should be the string representation from the server.""" - # unhexlify works fine with unicode input in everythin but pypy3, where it Raises "TypeError: 'str' does not support the buffer interface" + # unhexlify works fine with str input except on pypy3, where it raises "TypeError: 'str' does not support the buffer interface" if isinstance(token_string, str): token_string = token_string.encode("ascii") # The BOP stores a hex string diff --git a/cassandra/util.py b/cassandra/util.py index 4f5e9411b8..7893889f19 100644 --- a/cassandra/util.py +++ b/cassandra/util.py @@ -869,7 +869,8 @@ def __str__(self): inet_ntop = socket.inet_ntop -# similar to collections.namedtuple, reproduced here because Python 2.6 did not have the rename logic +# similar to collections.namedtuple, reproduced here to handle invalid identifiers +# by renaming them to positional names def _positional_rename_invalid_identifiers(field_names): names_out = list(field_names) for index, name in enumerate(field_names):