diff --git a/.asf.yaml b/.asf.yaml new file mode 100644 index 0000000000..0bacf232d1 --- /dev/null +++ b/.asf.yaml @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +notifications: + commits: commits@cassandra.apache.org + issues: commits@cassandra.apache.org + pullrequests: pr@cassandra.apache.org + jira_options: link worklog + +github: + description: "Python Driver for Apache Cassandra®" + homepage: https://docs.datastax.com/en/developer/python-driver/3.29/index.html + enabled_merge_buttons: + squash: false + merge: false + rebase: true + features: + wiki: false + issues: false + projects: false + discussions: false + autolink_jira: + - CASSANDRA + - CASSPYTHON + protected_branches: + trunk: + required_linear_history: true diff --git a/.gitignore b/.gitignore index 5c9cbec957..7983f44b87 100644 --- a/.gitignore +++ b/.gitignore @@ -11,7 +11,6 @@ build MANIFEST dist .coverage -nosetests.xml cover/ docs/_build/ tests/integration/ccm @@ -42,3 +41,6 @@ tests/unit/cython/bytesio_testhelper.c #iPython *.ipynb +venv +docs/venv +.eggs \ No newline at end of file diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 8e1e587229..0000000000 --- a/.travis.yml +++ /dev/null @@ -1,31 +0,0 @@ -language: python -python: 2.7 -sudo: false -env: - - TOX_ENV=py26 CASS_VER=21 CASS_DRIVER_NO_CYTHON=1 - - TOX_ENV=py26 CASS_VER=21 - - TOX_ENV=py27 CASS_VER=12 CASS_DRIVER_NO_CYTHON=1 - - TOX_ENV=py27 CASS_VER=20 CASS_DRIVER_NO_CYTHON=1 - - TOX_ENV=py27 CASS_VER=21 CASS_DRIVER_NO_CYTHON=1 - - TOX_ENV=py27 CASS_VER=21 - - TOX_ENV=py33 CASS_VER=21 CASS_DRIVER_NO_CYTHON=1 - - TOX_ENV=py33 CASS_VER=21 - - TOX_ENV=py34 CASS_VER=21 CASS_DRIVER_NO_CYTHON=1 - - TOX_ENV=py34 CASS_VER=21 - - TOX_ENV=pypy CASS_VER=21 CASS_DRIVER_NO_CYTHON=1 - - TOX_ENV=pypy3 CASS_VER=21 CASS_DRIVER_NO_CYTHON=1 - -addons: - apt: - packages: - - build-essential - - python-dev - - pypy-dev - - libev4 - - libev-dev - -install: - - pip install tox - -script: - - tox -e $TOX_ENV diff --git a/CHANGELOG.rst b/CHANGELOG.rst index d002af7774..6da84ae7a4 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,10 +1,903 @@ +3.29.3 +====== +October 20, 2025 + +Features +-------- +* Upgraded cython to 3.0.x (PR 1221 & PYTHON-1390) +* Add support for DSE 6.9.x and HCD releases to CI (PYTHON-1402) +* Add execute_concurrent_async and expose execute_concurrent_* in Session (PR 1229) + +Bug Fixes +--------- +* Update geomet to align with requirements.txt (PR 1236) +* Connection failure to SNI endpoint when first host is unavailable (PYTHON-1419) +* Maintain compatibility with CPython 3.13 (PR 1242) + +Others +------ +* Remove duplicated condition in primary key check (PR 1240) +* Remove Python 3.8 which reached EOL on Oct 2024, update cryptography lib to 42 (PR 1247) +* Remove obsolete urllib2 
from ez_setup.py (PR 1248) +* Remove stale dependency on sure (PR 1227) +* Removed 2.7 Cpython defines (PR 1252) + +3.29.2 +====== +September 9, 2024 + +Features +-------- +* Convert to pytest for running unit and integration tests (PYTHON-1297) +* Add support for Cassandra 4.1.x and 5.0 releases to CI (PYTHON-1393) +* Extend driver vector support to arbitrary subtypes and fix handling of variable length types (PYTHON-1369) + +Bug Fixes +--------- +* Python NumpyProtocolHandler does not work with NumPy 1.24.0 or greater (PYTHON-1359) +* cibuildwheel appears to not be stripping Cython-generated shared objects (PYTHON-1387) +* Windows build for Python 3.12 compiled without libev support (PYTHON-1386) + +Others +------ +* Update README.rst with badges for version and license (PR 1210) +* Remove dependency on old mock external module (PR 1201) +* Removed future print_function, division, and with and some pre 3.7 handling (PR 1208) +* Update geomet dependency (PR 1207) +* Remove problematic escape sequences in some docstrings to avoid SyntaxWarning in Python 3.12 (PR 1205) +* Use timezone-aware API to avoid deprecated warning (PR 1213) + +3.29.1 +====== +March 19, 2024 + +Bug Fixes +--------- +* cassandra-driver for Python 3.12 Linux is compiled without libev support (PYTHON-1378) +* Consider moving to native wheel builds for OS X and removing universal2 wheels (PYTHON-1379) + +3.29.0 +====== +December 19, 2023 + +Features +-------- +* Add support for Python 3.9 through 3.12, drop support for 3.7 (PYTHON-1283) +* Removal of dependency on six module (PR 1172) +* Raise explicit exception when deserializing a vector with a subtype that isn’t a constant size (PYTHON-1371) + +Others +------ +* Remove outdated Python pre-3.7 references (PR 1186) +* Remove backup(.bak) files (PR 1185) +* Fix doc typo in add_callbacks (PR 1177) + +3.28.0 +====== +June 5, 2023 + +Features +-------- +* Add support for vector type (PYTHON-1352) +* Cryptography module is now an optional dependency (PYTHON-1351) + +Bug Fixes +--------- +* Store IV along with encrypted text when using column-level encryption (PYTHON-1350) +* Create session-specific protocol handlers to contain session-specific CLE policies (PYTHON-1356) + +Others +------ +* Use Cython for smoke builds (PYTHON-1343) +* Don't fail when inserting UDTs with prepared queries with some missing fields (PR 1151) +* Convert print statement to function in docs (PR 1157) +* Update comment for retry policy (DOC-3278) +* Added error handling blog reference (DOC-2813) + +3.27.0 +====== +May 1, 2023 + +Features +-------- +* Add support for client-side encryption (PYTHON-1341) + +3.26.0 +====== +March 13, 2023 + +Features +-------- +* Add support for execution profiles in execute_concurrent (PR 1122) + +Bug Fixes +--------- +* Handle empty non-final result pages (PR 1110) +* Do not re-use stream IDs for in-flight requests (PR 1114) +* Asyncore race condition cause logging exception on shutdown (PYTHON-1266) + +Others +------ +* Fix deprecation warning in query tracing (PR 1103) +* Remove mutable default values from some tests (PR 1116) +* Remove dependency on unittest2 (PYTHON-1289) +* Fix deprecation warnings for asyncio.coroutine annotation in asyncioreactor (PYTHON-1290) +* Fix typos in source files (PR 1126) +* HostFilterPolicyInitTest fix for Python 3.11 (PR 1131) +* Fix for DontPrepareOnIgnoredHostsTest (PYTHON-1287) +* tests.integration.simulacron.test_connection failures (PYTHON-1304) +* tests.integration.standard.test_single_interface.py appears to be failing for 
C* 4.0 (PYTHON-1329) +* Authentication tests appear to be failing fraudulently (PYTHON-1328) +* PreparedStatementTests.test_fail_if_different_query_id_on_reprepare() failing unexpectedly (PTYHON-1327) +* Refactor deprecated unittest aliases for Python 3.11 compatibility (PR 1112) + +Deprecations +------------ +* This release removes support for Python 2.7.x as well as Python 3.5.x and 3.6.x + +3.25.0 +====== +March 18, 2021 + +Features +-------- +* Ensure the driver can connect when invalid peer hosts are in system.peers (PYTHON-1260) +* Implement protocol v5 checksumming (PYTHON-1258) +* Fix the default cqlengine connection mechanism to work with Astra (PYTHON-1265) + +Bug Fixes +--------- +* Asyncore race condition cause logging exception on shutdown (PYTHON-1266) +* Update list of reserved keywords (PYTHON-1269) + +Others +------ +* Drop Python 3.4 support (PYTHON-1220) +* Update security documentation and examples to use PROTOCOL_TLS (PYTHON-1264) + +3.24.0 +====== +June 18, 2020 + +Features +-------- +* Make geomet an optional dependency at runtime (PYTHON-1237) +* Add use_default_tempdir cloud config options (PYTHON-1245) +* Tcp flow control for libevreactor (PYTHON-1248) + +Bug Fixes +--------- +* Unable to connect to a cloud cluster using Ubuntu 20.04 (PYTHON-1238) +* PlainTextAuthProvider fails with unicode chars and Python3 (PYTHON-1241) +* [GRAPH] Graph execution profiles consistency level are not set to LOCAL_QUORUM with a cloud cluster (PYTHON-1240) +* [GRAPH] Can't write data in a Boolean field using the Fluent API (PYTHON-1239) +* [GRAPH] Fix elementMap() result deserialization (PYTHON-1233) + +Others +------ +* Bump geomet dependency version to 0.2 (PYTHON-1243) +* Bump gremlinpython dependency version to 3.4.6 (PYTHON-1212) +* Improve fluent graph documentation for core graphs (PYTHON-1244) + +3.23.0 +====== +April 6, 2020 + +Features +-------- +* Transient Replication Support (PYTHON-1207) +* Support system.peers_v2 and port discovery for C* 4.0 (PYTHON-700) + +Bug Fixes +--------- +* Asyncore logging exception on shutdown (PYTHON-1228) + +3.22.0 +====== +February 26, 2020 + +Features +-------- + +* Add all() function to the ResultSet API (PYTHON-1203) +* Parse new schema metadata in NGDG and generate table edges CQL syntax (PYTHON-996) +* Add GraphSON3 support (PYTHON-788) +* Use GraphSON3 as default for Native graphs (PYTHON-1004) +* Add Tuple and UDT types for native graph (PYTHON-1005) +* Add Duration type for native graph (PYTHON-1000) +* Add gx:ByteBuffer graphson type support for Blob field (PYTHON-1027) +* Enable Paging Through DSE Driver for Gremlin Traversals (PYTHON-1045) +* Provide numerical wrappers to ensure proper graphson schema definition (PYTHON-1051) +* Resolve the row_factory automatically for native graphs (PYTHON-1056) +* Add g:TraversalMetrics/g:Metrics graph deserializers (PYTHON-1057) +* Add g:BulkSet graph deserializers (PYTHON-1060) +* Update Graph Engine names and the way to create a Classic/Native Graph (PYTHON-1090) +* Update Native to Core Graph Engine +* Add graphson3 and native graph support (PYTHON-1039) +* Enable Paging Through DSE Driver for Gremlin Traversals (PYTHON-1045) +* Expose filter predicates for cql collections (PYTHON-1019) +* Add g:TraversalMetrics/Metrics deserializers (PYTHON-1057) +* Make graph metadata handling more robust (PYTHON-1204) + +Bug Fixes +--------- +* Make sure to only query the native_transport_address column with DSE (PYTHON-1205) + +3.21.0 +====== +January 15, 2020 + +Features +-------- +* Unified driver: 
merge core and DSE drivers into a single package (PYTHON-1130) +* Add Python 3.8 support (PYTHON-1189) +* Allow passing ssl context for Twisted (PYTHON-1161) +* Ssl context and cloud support for Eventlet (PYTHON-1162) +* Cloud Twisted support (PYTHON-1163) +* Add additional_write_policy and read_repair to system schema parsing (PYTHON-1048) +* Flexible version parsing (PYTHON-1174) +* Support NULL in collection deserializer (PYTHON-1123) +* [GRAPH] Ability to execute Fluent Graph queries asynchronously (PYTHON-1129) + +Bug Fixes +--------- +* Handle prepared id mismatch when repreparing on the fly (PYTHON-1124) +* re-raising the CQLEngineException will fail on Python 3 (PYTHON-1166) +* asyncio message chunks can be processed discontinuously (PYTHON-1185) +* Reconnect attempts persist after downed node removed from peers (PYTHON-1181) +* Connection fails to validate ssl certificate hostname when SSLContext.check_hostname is set (PYTHON-1186) +* ResponseFuture._set_result crashes on connection error when used with PrepareMessage (PYTHON-1187) +* Insights fail to serialize the startup message when the SSL Context is from PyOpenSSL (PYTHON-1192) + +Others +------ +* The driver has a new dependency: geomet. It comes from the dse-driver unification and + is used to support DSE geo types. +* Remove *read_repair_chance table options (PYTHON-1140) +* Avoid warnings about unspecified load balancing policy when connecting to a cloud cluster (PYTHON-1177) +* Add new DSE CQL keywords (PYTHON-1122) +* Publish binary wheel distributions (PYTHON-1013) + +Deprecations +------------ + +* DSELoadBalancingPolicy will be removed in the next major, consider using + the DefaultLoadBalancingPolicy. + +Merged from dse-driver: + +Features +-------- + +* Insights integration (PYTHON-1047) +* Graph execution profiles should preserve their graph_source when graph_options is overridden (PYTHON-1021) +* Add NodeSync metadata (PYTHON-799) +* Add new NodeSync failure values (PYTHON-934) +* DETERMINISTIC and MONOTONIC Clauses for Functions and Aggregates (PYTHON-955) +* GraphOptions should show a warning for unknown parameters (PYTHON-819) +* DSE protocol version 2 and continous paging backpressure (PYTHON-798) +* GraphSON2 Serialization/Deserialization Support (PYTHON-775) +* Add graph-results payload option for GraphSON format (PYTHON-773) +* Create an AuthProvider for the DSE transitional mode (PYTHON-831) +* Implement serializers for the Graph String API (PYTHON-778) +* Provide deserializers for GraphSON types (PYTHON-782) +* Add Graph DurationType support (PYTHON-607) +* Support DSE DateRange type (PYTHON-668) +* RLAC CQL output for materialized views (PYTHON-682) +* Add Geom Types wkt deserializer +* DSE Graph Client timeouts in custom payload (PYTHON-589) +* Make DSEGSSAPIAuthProvider accept principal name (PYTHON-574) +* Add config profiles to DSE graph execution (PYTHON-570) +* DSE Driver version checking (PYTHON-568) +* Distinct default timeout for graph queries (PYTHON-477) +* Graph result parsing for known types (PYTHON-479,487) +* Distinct read/write CL for graph execution (PYTHON-509) +* Target graph analytics query to spark master when available (PYTHON-510) + +Bug Fixes +--------- + +* Continuous paging sessions raise RuntimeError when results are not entirely consumed (PYTHON-1054) +* GraphSON Property deserializer should return a dict instead of a set (PYTHON-1033) +* ResponseFuture.has_more_pages may hold the wrong value (PYTHON-946) +* DETERMINISTIC clause in AGGREGATE misplaced in CQL generation 
(PYTHON-963) +* graph module import cause a DLL issue on Windows due to its cythonizing failure (PYTHON-900) +* Update date serialization to isoformat in graph (PYTHON-805) +* DateRange Parse Error (PYTHON-729) +* MontonicTimestampGenerator.__init__ ignores class defaults (PYTHON-728) +* metadata.get_host returning None unexpectedly (PYTHON-709) +* Sockets associated with sessions not getting cleaned up on session.shutdown() (PYTHON-673) +* Resolve FQDN from ip address and use that as host passed to SASLClient (PYTHON-566) +* Geospatial type implementations don't handle 'EMPTY' values. (PYTHON-481) +* Correctly handle other types in geo type equality (PYTHON-508) + +Other +----- +* Add tests around cqlengine and continuous paging (PYTHON-872) +* Add an abstract GraphStatement to handle different graph statements (PYTHON-789) +* Write documentation examples for DSE 2.0 features (PYTHON-732) +* DSE_V1 protocol should not include all of protocol v5 (PYTHON-694) + +3.20.2 +====== +November 19, 2019 + +Bug Fixes +--------- +* Fix import error for old python installation without SSLContext (PYTHON-1183) + +3.20.1 +====== +November 6, 2019 + +Bug Fixes +--------- +* ValueError: too many values to unpack (expected 2)" when there are two dashes in server version number (PYTHON-1172) + +3.20.0 +====== +October 28, 2019 + +Features +-------- +* DataStax Astra Support (PYTHON-1074) +* Use 4.0 schema parser in 4 alpha and snapshot builds (PYTHON-1158) + +Bug Fixes +--------- +* Connection setup methods prevent using ExecutionProfile in cqlengine (PYTHON-1009) +* Driver deadlock if all connections dropped by heartbeat whilst request in flight and request times out (PYTHON-1044) +* Exception when use pk__token__gt filter In python 3.7 (PYTHON-1121) + +3.19.0 +====== +August 26, 2019 + +Features +-------- +* Add Python 3.7 support (PYTHON-1016) +* Future-proof Mapping imports (PYTHON-1023) +* Include param values in cqlengine logging (PYTHON-1105) +* NTS Token Replica Map Generation is slow (PYTHON-622) + +Bug Fixes +--------- +* as_cql_query UDF/UDA parameters incorrectly includes "frozen" if arguments are collections (PYTHON-1031) +* cqlengine does not currently support combining TTL and TIMESTAMP on INSERT (PYTHON-1093) +* Fix incorrect metadata for compact counter tables (PYTHON-1100) +* Call ConnectionException with correct kwargs (PYTHON-1117) +* Can't connect to clusters built from source because version parsing doesn't handle 'x.y-SNAPSHOT' (PYTHON-1118) +* Discovered node doesn´t honor the configured Cluster port on connection (PYTHON-1127) +* Exception when use pk__token__gt filter In python 3.7 (PYTHON-1121) + +Other +----- +* Remove invalid warning in set_session when we initialize a default connection (PYTHON-1104) +* Set the proper default ExecutionProfile.row_factory value (PYTHON-1119) + +3.18.0 +====== +May 27, 2019 + +Features +-------- + +* Abstract Host Connection information (PYTHON-1079) +* Improve version parsing to support a non-integer 4th component (PYTHON-1091) +* Expose on_request_error method in the RetryPolicy (PYTHON-1064) +* Add jitter to ExponentialReconnectionPolicy (PYTHON-1065) + +Bug Fixes +--------- + +* Fix error when preparing queries with beta protocol v5 (PYTHON-1081) +* Accept legacy empty strings as column names (PYTHON-1082) +* Let util.SortedSet handle uncomparable elements (PYTHON-1087) + +3.17.1 +====== +May 2, 2019 + +Bug Fixes +--------- +* Socket errors EAGAIN/EWOULDBLOCK are not handled properly and cause timeouts (PYTHON-1089) + +3.17.0 +====== 
+February 19, 2019 + +Features +-------- +* Send driver name and version in startup message (PYTHON-1068) +* Add Cluster ssl_context option to enable SSL (PYTHON-995) +* Allow encrypted private keys for 2-way SSL cluster connections (PYTHON-995) +* Introduce new method ConsistencyLevel.is_serial (PYTHON-1067) +* Add Session.get_execution_profile (PYTHON-932) +* Add host kwarg to Session.execute/execute_async APIs to send a query to a specific node (PYTHON-993) + +Bug Fixes +--------- +* NoHostAvailable when all hosts are up and connectable (PYTHON-891) +* Serial consistency level is not used (PYTHON-1007) + +Other +----- +* Fail faster on incorrect lz4 import (PYTHON-1042) +* Bump Cython dependency version to 0.29 (PYTHON-1036) +* Expand Driver SSL Documentation (PYTHON-740) + +Deprecations +------------ + +* Using Cluster.ssl_options to enable SSL is deprecated and will be removed in + the next major release, use ssl_context. +* DowngradingConsistencyRetryPolicy is deprecated and will be + removed in the next major release. (PYTHON-937) + +3.16.0 +====== +November 12, 2018 + +Bug Fixes +--------- +* Improve and fix socket error-catching code in nonblocking-socket reactors (PYTHON-1024) +* Non-ASCII characters in schema break CQL string generation (PYTHON-1008) +* Fix OSS driver's virtual table support against DSE 6.0.X and future server releases (PYTHON-1020) +* ResultSet.one() fails if the row_factory is using a generator (PYTHON-1026) +* Log profile name on attempt to create existing profile (PYTHON-944) +* Cluster instantiation fails if any contact points' hostname resolution fails (PYTHON-895) + +Other +----- +* Fix tests when RF is not maintained if we decomission a node (PYTHON-1017) +* Fix wrong use of ResultSet indexing (PYTHON-1015) + +3.15.1 +====== +September 6, 2018 + +Bug Fixes +--------- +* C* 4.0 schema-parsing logic breaks running against DSE 6.0.X (PYTHON-1018) + +3.15.0 +====== +August 30, 2018 + +Features +-------- +* Parse Virtual Keyspace Metadata (PYTHON-992) + +Bug Fixes +--------- +* Tokenmap.get_replicas returns the wrong value if token coincides with the end of the range (PYTHON-978) +* Python Driver fails with "more than 255 arguments" python exception when > 255 columns specified in query response (PYTHON-893) +* Hang in integration.standard.test_cluster.ClusterTests.test_set_keyspace_twice (PYTHON-998) +* Asyncore reactors should use a global variable instead of a class variable for the event loop (PYTHON-697) + +Other +----- +* Use global variable for libev loops so it can be subclassed (PYTHON-973) +* Update SchemaParser for V4 (PYTHON-1006) +* Bump Cython dependency version to 0.28 (PYTHON-1012) + +3.14.0 +====== +April 17, 2018 + +Features +-------- +* Add one() function to the ResultSet API (PYTHON-947) +* Create an utility function to fetch concurrently many keys from the same replica (PYTHON-647) +* Allow filter queries with fields that have an index managed outside of cqlengine (PYTHON-966) +* Twisted SSL Support (PYTHON-343) +* Support IS NOT NULL operator in cqlengine (PYTHON-968) + +Other +----- +* Fix Broken Links in Docs (PYTHON-916) +* Reevaluate MONKEY_PATCH_LOOP in test codebase (PYTHON-903) +* Remove CASS_SERVER_VERSION and replace it for CASSANDRA_VERSION in tests (PYTHON-910) +* Refactor CASSANDRA_VERSION to a some kind of version object (PYTHON-915) +* Log warning when driver configures an authenticator, but server does not request authentication (PYTHON-940) +* Warn users when using the deprecated Session.default_consistency_level 
(PYTHON-953) +* Add DSE smoke test to OSS driver tests (PYTHON-894) +* Document long compilation times and workarounds (PYTHON-868) +* Improve error for batch WriteTimeouts (PYTHON-941) +* Deprecate ResultSet indexing (PYTHON-945) + +3.13.0 +====== +January 30, 2018 + +Features +-------- +* cqlengine: LIKE filter operator (PYTHON-512) +* Support cassandra.query.BatchType with cqlengine BatchQuery (PYTHON-888) + +Bug Fixes +--------- +* AttributeError: 'NoneType' object has no attribute 'add_timer' (PYTHON-862) +* Support retry_policy in PreparedStatement (PYTHON-861) +* __del__ method in Session is throwing an exception (PYTHON-813) +* LZ4 import issue with recent versions (PYTHON-897) +* ResponseFuture._connection can be None when returning request_id (PYTHON-853) +* ResultSet.was_applied doesn't support batch with LWT statements (PYTHON-848) + +Other +----- +* cqlengine: avoid warning when unregistering connection on shutdown (PYTHON-865) +* Fix DeprecationWarning of log.warn (PYTHON-846) +* Fix example_mapper.py for python3 (PYTHON-860) +* Possible deadlock on cassandra.concurrent.execute_concurrent (PYTHON-768) +* Add some known deprecated warnings for 4.x (PYTHON-877) +* Remove copyright dates from copyright notices (PYTHON-863) +* Remove "Experimental" tag from execution profiles documentation (PYTHON-840) +* request_timer metrics descriptions are slightly incorrect (PYTHON-885) +* Remove "Experimental" tag from cqlengine connections documentation (PYTHON-892) +* Set in documentation default consistency for operations is LOCAL_ONE (PYTHON-901) + +3.12.0 +====== +November 6, 2017 + +Features +-------- +* Send keyspace in QUERY, PREPARE, and BATCH messages (PYTHON-678) +* Add IPv4Address/IPv6Address support for inet types (PYTHON-751) +* WriteType.CDC and VIEW missing (PYTHON-794) +* Warn on Cluster init if contact points are specified but LBP isn't (legacy mode) (PYTHON-812) +* Warn on Cluster init if contact points are specified but LBP isn't (exection profile mode) (PYTHON-838) +* Include hash of result set metadata in prepared stmt id (PYTHON-808) +* Add NO_COMPACT startup option (PYTHON-839) +* Add new exception type for CDC (PYTHON-837) +* Allow 0ms in ConstantSpeculativeExecutionPolicy (PYTHON-836) +* Add asyncio reactor (PYTHON-507) + +Bug Fixes +--------- +* Both _set_final_exception/result called for the same ResponseFuture (PYTHON-630) +* Use of DCAwareRoundRobinPolicy raises NoHostAvailable exception (PYTHON-781) +* Not create two sessions by default in CQLEngine (PYTHON-814) +* Bug when subclassing AyncoreConnection (PYTHON-827) +* Error at cleanup when closing the asyncore connections (PYTHON-829) +* Fix sites where `sessions` can change during iteration (PYTHON-793) +* cqlengine: allow min_length=0 for Ascii and Text column types (PYTHON-735) +* Rare exception when "sys.exit(0)" after query timeouts (PYTHON-752) +* Dont set the session keyspace when preparing statements (PYTHON-843) +* Use of DCAwareRoundRobinPolicy raises NoHostAvailable exception (PYTHON-781) + +Other +------ +* Remove DeprecationWarning when using WhiteListRoundRobinPolicy (PYTHON-810) +* Bump Cython dependency version to 0.27 (PYTHON-833) + +3.11.0 +====== +July 24, 2017 + + +Features +-------- +* Add idle_heartbeat_timeout cluster option to tune how long to wait for heartbeat responses. 
(PYTHON-762) +* Add HostFilterPolicy (PYTHON-761) + +Bug Fixes +--------- +* is_idempotent flag is not propagated from PreparedStatement to BoundStatement (PYTHON-736) +* Fix asyncore hang on exit (PYTHON-767) +* Driver takes several minutes to remove a bad host from session (PYTHON-762) +* Installation doesn't always fall back to no cython in Windows (PYTHON-763) +* Avoid to replace a connection that is supposed to shutdown (PYTHON-772) +* request_ids may not be returned to the pool (PYTHON-739) +* Fix murmur3 on big-endian systems (PYTHON-653) +* Ensure unused connections are closed if a Session is deleted by the GC (PYTHON-774) +* Fix .values_list by using db names internally (cqlengine) (PYTHON-785) + + +Other +----- +* Bump Cython dependency version to 0.25.2 (PYTHON-754) +* Fix DeprecationWarning when using lz4 (PYTHON-769) +* Deprecate WhiteListRoundRobinPolicy (PYTHON-759) +* Improve upgrade guide for materializing pages (PYTHON-464) +* Documentation for time/date specifies timestamp inupt as microseconds (PYTHON-717) +* Point to DSA Slack, not IRC, in docs index + +3.10.0 +====== +May 24, 2017 + +Features +-------- +* Add Duration type to cqlengine (PYTHON-750) +* Community PR review: Raise error on primary key update only if its value changed (PYTHON-705) +* get_query_trace() contract is ambiguous (PYTHON-196) + +Bug Fixes +--------- +* Queries using speculative execution policy timeout prematurely (PYTHON-755) +* Fix `map` where results are not consumed (PYTHON-749) +* Driver fails to encode Duration's with large values (PYTHON-747) +* UDT values are not updated correctly in CQLEngine (PYTHON-743) +* UDT types are not validated in CQLEngine (PYTHON-742) +* to_python is not implemented for types columns.Type and columns.Date in CQLEngine (PYTHON-741) +* Clients spin infinitely trying to connect to a host that is drained (PYTHON-734) +* Resulset.get_query_trace returns empty trace sometimes (PYTHON-730) +* Memory grows and doesn't get removed (PYTHON-720) +* Fix RuntimeError caused by change dict size during iteration (PYTHON-708) +* fix ExponentialReconnectionPolicy may throw OverflowError problem (PYTHON-707) +* Avoid using nonexistent prepared statement in ResponseFuture (PYTHON-706) + +Other +----- +* Update README (PYTHON-746) +* Test python versions 3.5 and 3.6 (PYTHON-737) +* Docs Warning About Prepare "select *" (PYTHON-626) +* Increase Coverage in CqlEngine Test Suite (PYTHON-505) +* Example SSL connection code does not verify server certificates (PYTHON-469) + +3.9.0 +===== + +Features +-------- +* cqlengine: remove elements by key from a map (PYTHON-688) + +Bug Fixes +--------- +* improve error handling when connecting to non-existent keyspace (PYTHON-665) +* Sockets associated with sessions not getting cleaned up on session.shutdown() (PYTHON-673) +* rare flake on integration.standard.test_cluster.ClusterTests.test_clone_shared_lbp (PYTHON-727) +* MontonicTimestampGenerator.__init__ ignores class defaults (PYTHON-728) +* race where callback or errback for request may not be called (PYTHON-733) +* cqlengine: model.update() should not update columns with a default value that hasn't changed (PYTHON-657) +* cqlengine: field value manager's explicit flag is True when queried back from cassandra (PYTHON-719) + +Other +----- +* Connection not closed in example_mapper (PYTHON-723) +* Remove mention of pre-2.0 C* versions from OSS 3.0+ docs (PYTHON-710) + +3.8.1 +===== +March 16, 2017 + +Bug Fixes +--------- + +* implement __le__/__ge__/__ne__ on some custom types (PYTHON-714) 
+* Fix bug in eventlet and gevent reactors that could cause hangs (PYTHON-721) +* Fix DecimalType regression (PYTHON-724) + +3.8.0 +===== + +Features +-------- + +* Quote index names in metadata CQL generation (PYTHON-616) +* On column deserialization failure, keep error message consistent between python and cython (PYTHON-631) +* TokenAwarePolicy always sends requests to the same replica for a given key (PYTHON-643) +* Added cql types to result set (PYTHON-648) +* Add __len__ to BatchStatement (PYTHON-650) +* Duration Type for Cassandra (PYTHON-655) +* Send flags with PREPARE message in v5 (PYTHON-684) + +Bug Fixes +--------- + +* Potential Timing issue if application exits prior to session pool initialization (PYTHON-636) +* "Host X.X.X.X has been marked down" without any exceptions (PYTHON-640) +* NoHostAvailable or OperationTimedOut when using execute_concurrent with a generator that inserts into more than one table (PYTHON-642) +* ResponseFuture creates Timers and don't cancel them even when result is received which leads to memory leaks (PYTHON-644) +* Driver cannot connect to Cassandra version > 3 (PYTHON-646) +* Unable to import model using UserType without setuping connection since 3.7 (PYTHON-649) +* Don't prepare queries on ignored hosts on_up (PYTHON-669) +* Sockets associated with sessions not getting cleaned up on session.shutdown() (PYTHON-673) +* Make client timestamps strictly monotonic (PYTHON-676) +* cassandra.cqlengine.connection.register_connection broken when hosts=None (PYTHON-692) + +Other +----- + +* Create a cqlengine doc section explaining None semantics (PYTHON-623) +* Resolve warnings in documentation generation (PYTHON-645) +* Cython dependency (PYTHON-686) +* Drop Support for Python 2.6 (PYTHON-690) + +3.7.1 +===== +October 26, 2016 + +Bug Fixes +--------- +* Cython upgrade has broken stable version of cassandra-driver (PYTHON-656) + +3.7.0 +===== +September 13, 2016 + +Features +-------- +* Add v5 protocol failure map (PYTHON-619) +* Don't return from initial connect on first error (PYTHON-617) +* Indicate failed column when deserialization fails (PYTHON-361) +* Let Cluster.refresh_nodes force a token map rebuild (PYTHON-349) +* Refresh UDTs after "keyspace updated" event with v1/v2 protocol (PYTHON-106) +* EC2 Address Resolver (PYTHON-198) +* Speculative query retries (PYTHON-218) +* Expose paging state in API (PYTHON-200) +* Don't mark host down while one connection is active (PYTHON-498) +* Query request size information (PYTHON-284) +* Avoid quadratic ring processing with invalid replication factors (PYTHON-379) +* Improve Connection/Pool creation concurrency on startup (PYTHON-82) +* Add beta version native protocol flag (PYTHON-614) +* cqlengine: Connections: support of multiple keyspaces and sessions (PYTHON-613) + +Bug Fixes +--------- +* Race when adding a pool while setting keyspace (PYTHON-628) +* Update results_metadata when prepared statement is reprepared (PYTHON-621) +* CQL Export for Thrift Tables (PYTHON-213) +* cqlengine: default value not applied to UserDefinedType (PYTHON-606) +* cqlengine: columns are no longer hashable (PYTHON-618) +* cqlengine: remove clustering keys from where clause when deleting only static columns (PYTHON-608) + +3.6.0 +===== +August 1, 2016 + +Features +-------- +* Handle null values in NumpyProtocolHandler (PYTHON-553) +* Collect greplin scales stats per cluster (PYTHON-561) +* Update mock unit test dependency requirement (PYTHON-591) +* Handle Missing CompositeType metadata following C* upgrade (PYTHON-562) +* 
Improve Host.is_up state for HostDistance.IGNORED hosts (PYTHON-551) +* Utilize v2 protocol's ability to skip result set metadata for prepared statement execution (PYTHON-71) +* Return from Cluster.connect() when first contact point connection(pool) is opened (PYTHON-105) +* cqlengine: Add ContextQuery to allow cqlengine models to switch the keyspace context easily (PYTHON-598) +* Standardize Validation between Ascii and Text types in Cqlengine (PYTHON-609) + +Bug Fixes +--------- +* Fix geventreactor with SSL support (PYTHON-600) +* Don't downgrade protocol version if explicitly set (PYTHON-537) +* Nonexistent contact point tries to connect indefinitely (PYTHON-549) +* Execute_concurrent can exceed max recursion depth in failure mode (PYTHON-585) +* Libev loop shutdown race (PYTHON-578) +* Include aliases in DCT type string (PYTHON-579) +* cqlengine: Comparison operators for Columns (PYTHON-595) +* cqlengine: disentangle default_time_to_live table option from model query default TTL (PYTHON-538) +* cqlengine: pk__token column name issue with the equality operator (PYTHON-584) +* cqlengine: Fix "__in" filtering operator converts True to string "True" automatically (PYTHON-596) +* cqlengine: Avoid LWTExceptions when updating columns that are part of the condition (PYTHON-580) +* cqlengine: Cannot execute a query when the filter contains all columns (PYTHON-599) +* cqlengine: routing key computation issue when a primary key column is overriden by model inheritance (PYTHON-576) + +3.5.0 +===== +June 27, 2016 + +Features +-------- +* Optional Execution Profiles for the core driver (PYTHON-569) +* API to get the host metadata associated with the control connection node (PYTHON-583) +* Expose CDC option in table metadata CQL (PYTHON-593) + +Bug Fixes +--------- +* Clean up Asyncore socket map when fork is detected (PYTHON-577) +* cqlengine: QuerySet only() is not respected when there are deferred fields (PYTHON-560) + +3.4.1 +===== +May 26, 2016 + +Bug Fixes +--------- +* Gevent connection closes on IO timeout (PYTHON-573) +* "dictionary changed size during iteration" with Python 3 (PYTHON-572) + +3.4.0 +===== +May 24, 2016 + +Features +-------- +* Include DSE version and workload in Host data (PYTHON-555) +* Add a context manager to Cluster and Session (PYTHON-521) +* Better Error Message for Unsupported Protocol Version (PYTHON-157) +* Make the error message explicitly state when an error comes from the server (PYTHON-412) +* Short Circuit meta refresh on topo change if NEW_NODE already exists (PYTHON-557) +* Show warning when the wrong config is passed to SimpleStatement (PYTHON-219) +* Return namedtuple result pairs from execute_concurrent (PYTHON-362) +* BatchStatement should enforce batch size limit in a better way (PYTHON-151) +* Validate min/max request thresholds for connection pool scaling (PYTHON-220) +* Handle or warn about multiple hosts with the same rpc_address (PYTHON-365) +* Write docs around working with datetime and timezones (PYTHON-394) + +Bug Fixes +--------- +* High CPU utilization when using asyncore event loop (PYTHON-239) +* Fix CQL Export for non-ASCII Identifiers (PYTHON-447) +* Make stress scripts Python 2.6 compatible (PYTHON-434) +* UnicodeDecodeError when unicode characters in key in BOP (PYTHON-559) +* WhiteListRoundRobinPolicy should resolve hosts (PYTHON-565) +* Cluster and Session do not GC after leaving scope (PYTHON-135) +* Don't wait for schema agreement on ignored nodes (PYTHON-531) +* Reprepare on_up with many clients causes node overload (PYTHON-556) +* 
None inserted into host map when control connection node is decommissioned (PYTHON-548) +* weakref.ref does not accept keyword arguments (github #585) + +3.3.0 +===== +May 2, 2016 + +Features +-------- +* Add an AddressTranslator interface (PYTHON-69) +* New Retry Policy Decision - try next host (PYTHON-285) +* Don't mark host down on timeout (PYTHON-286) +* SSL hostname verification (PYTHON-296) +* Add C* version to metadata or cluster objects (PYTHON-301) +* Options to Disable Schema, Token Metadata Processing (PYTHON-327) +* Expose listen_address of node we get ring information from (PYTHON-332) +* Use A-record with multiple IPs for contact points (PYTHON-415) +* Custom consistency level for populating query traces (PYTHON-435) +* Normalize Server Exception Types (PYTHON-443) +* Propagate exception message when DDL schema agreement fails (PYTHON-444) +* Specialized exceptions for metadata refresh methods failure (PYTHON-527) + +Bug Fixes +--------- +* Resolve contact point hostnames to avoid duplicate hosts (PYTHON-103) +* GeventConnection stalls requests when read is a multiple of the input buffer size (PYTHON-429) +* named_tuple_factory breaks with duplicate "cleaned" col names (PYTHON-467) +* Connection leak if Cluster.shutdown() happens during reconnection (PYTHON-482) +* HostConnection.borrow_connection does not block when all request ids are used (PYTHON-514) +* Empty field not being handled by the NumpyProtocolHandler (PYTHON-550) + +3.2.2 +===== +April 19, 2016 + +* Fix counter save-after-no-update (PYTHON-547) + +3.2.1 +===== +April 13, 2016 + +* Introduced an update to allow deserializer compilation with recently released Cython 0.24 (PYTHON-542) + +3.2.0 +===== +April 12, 2016 + +Features +-------- +* cqlengine: Warn on sync_schema type mismatch (PYTHON-260) +* cqlengine: Automatically defer fields with the '=' operator (and immutable values) in select queries (PYTHON-520) +* cqlengine: support non-equal conditions for LWT (PYTHON-528) +* cqlengine: sync_table should validate the primary key composition (PYTHON-532) +* cqlengine: token-aware routing for mapper statements (PYTHON-535) + +Bug Fixes +--------- +* Deleting a column in a lightweight transaction raises a SyntaxException #325 (PYTHON-249) +* cqlengine: make Token function works with named tables/columns #86 (PYTHON-272) +* comparing models with datetime fields fail #79 (PYTHON-273) +* cython date deserializer integer math should be aligned with CPython (PYTHON-480) +* db_field is not always respected with UpdateStatement (PYTHON-530) +* Sync_table fails on column.Set with secondary index (PYTHON-533) + 3.1.1 ===== March 14, 2016 Bug Fixes --------- - * cqlengine: Fix performance issue related to additional "COUNT" queries (PYTHON-522) 3.1.0 @@ -122,7 +1015,7 @@ Bug Fixes 2.7.2 ===== -Setpember 14, 2015 +September 14, 2015 Bug Fixes --------- diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 6d39d8df69..f71ebabdbb 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -5,7 +5,8 @@ Contributions are welcome in the form of bug reports or pull requests. Bug Reports ----------- -Quality bug reports are welcome at the `DataStax Python Drvier JIRA `_. +Quality bug reports are welcome at the `CASSPYTHON project `_ +of the ASF JIRA. There are plenty of `good resources `_ describing how to create good bug reports. They will not be repeated in detail here, but in general, the bug report include where appropriate: @@ -17,16 +18,11 @@ good bug reports. 
They will not be repeated in detail here, but in general, the Pull Requests ------------- -If you're able to fix a bug yourself, you can [fork the repository](https://help.github.com/articles/fork-a-repo/) and submit a [Pull Request](https://help.github.com/articles/using-pull-requests/) with the fix. -Please include tests demonstrating the issue and fix. For examples of how to run the tests, consult the `dev README `_. - -Contribution License Agreement ------------------------------- -To protect the community, all contributors are required to [sign the DataStax Contribution License Agreement](http://cla.datastax.com/). The process is completely electronic and should only take a few minutes. +If you're able to fix a bug yourself, you can `fork the repository `_ and submit a `Pull Request `_ with the fix. +Please include tests demonstrating the issue and fix. For examples of how to run the tests, consult the `dev README `_. Design and Implementation Guidelines ------------------------------------ -- We support Python 2.6+, so any changes must work in any of these runtimes (we use ``six``, ``futures``, and some internal backports for compatability) - We have integrations (notably Cassandra cqlsh) that require pure Python and minimal external dependencies. We try to avoid new external dependencies. Where compiled extensions are concerned, there should always be a pure Python fallback implementation. - This project follows `semantic versioning `_, so breaking API changes will only be introduced in major versions. - Legacy ``cqlengine`` has varying degrees of overreaching client-side validation. Going forward, we will avoid client validation where server feedback is adequate and not overly expensive. diff --git a/Jenkinsfile b/Jenkinsfile new file mode 100644 index 0000000000..dc70acc6c5 --- /dev/null +++ b/Jenkinsfile @@ -0,0 +1,746 @@ +#!groovy +/* + +There are multiple combinations to test the python driver. + +Test Profiles: + + Full: Execute all unit and integration tests, including long tests. + Standard: Execute unit and integration tests. + Smoke Tests: Execute a small subset of tests. + EVENT_LOOP: Execute a small subset of tests selected to test EVENT_LOOPs. + +Matrix Types: + + Full: All server versions, python runtimes tested with and without Cython. + Cassandra: All cassandra server versions. + Dse: All dse server versions. + Hcd: All hcd server versions. + Smoke: CI-friendly configurations. Currently-supported Python version + modern Cassandra/DSE instances. 
+ We also avoid cython since it's tested as part of the nightlies + +Parameters: + + EVENT_LOOP: 'LIBEV' (Default), 'GEVENT', 'EVENTLET', 'ASYNCIO', 'ASYNCORE', 'TWISTED' + CYTHON: Default, 'True', 'False' + +*/ + +@Library('dsdrivers-pipeline-lib@develop') +import com.datastax.jenkins.drivers.python.Slack + +slack = new Slack() + +DEFAULT_CASSANDRA = ['3.11', '4.0', '4.1', '5.0'] +DEFAULT_DSE = ['dse-5.1.35', 'dse-6.8.30', 'dse-6.9.0'] +DEFAULT_HCD = ['hcd-1.0.0'] +DEFAULT_RUNTIME = ['3.9.23', '3.10.18', '3.11.13', '3.12.11', '3.13.5'] +DEFAULT_CYTHON = ["True", "False"] +matrices = [ + "FULL": [ + "SERVER": DEFAULT_CASSANDRA + DEFAULT_DSE, + "RUNTIME": DEFAULT_RUNTIME, + "CYTHON": DEFAULT_CYTHON + ], + "CASSANDRA": [ + "SERVER": DEFAULT_CASSANDRA, + "RUNTIME": DEFAULT_RUNTIME, + "CYTHON": DEFAULT_CYTHON + ], + "DSE": [ + "SERVER": DEFAULT_DSE, + "RUNTIME": DEFAULT_RUNTIME, + "CYTHON": DEFAULT_CYTHON + ], + "SMOKE": [ + "SERVER": DEFAULT_CASSANDRA.takeRight(2) + DEFAULT_DSE.takeRight(2) + DEFAULT_HCD.takeRight(1), + "RUNTIME": DEFAULT_RUNTIME.take(1) + DEFAULT_RUNTIME.takeRight(1), + "CYTHON": ["True"] + ] +] + +def initializeSlackContext() { + /* + Based on git branch/commit, configure the build context and env vars. + */ + + def driver_display_name = 'Cassandra Python Driver' + if (env.GIT_URL.contains('riptano/python-driver')) { + driver_display_name = 'private ' + driver_display_name + } else if (env.GIT_URL.contains('python-dse-driver')) { + driver_display_name = 'DSE Python Driver' + } + env.DRIVER_DISPLAY_NAME = driver_display_name + env.GIT_SHA = "${env.GIT_COMMIT.take(7)}" + env.GITHUB_PROJECT_URL = "https://${GIT_URL.replaceFirst(/(git@|http:\/\/|https:\/\/)/, '').replace(':', '/').replace('.git', '')}" + env.GITHUB_BRANCH_URL = "${env.GITHUB_PROJECT_URL}/tree/${env.BRANCH_NAME}" + env.GITHUB_COMMIT_URL = "${env.GITHUB_PROJECT_URL}/commit/${env.GIT_COMMIT}" +} + +def getBuildContext() { + /* + Based on schedule and parameters, configure the build context and env vars. + */ + + def PROFILE = "${params.PROFILE}" + def EVENT_LOOP = "${params.EVENT_LOOP.toLowerCase()}" + + matrixType = params.MATRIX != "DEFAULT" ? params.MATRIX : "SMOKE" + matrix = matrices[matrixType].clone() + + // Check if parameters were set explicitly + if (params.CYTHON != "DEFAULT") { + matrix["CYTHON"] = [params.CYTHON] + } + + if (params.SERVER_VERSION != "DEFAULT") { + matrix["SERVER"] = [params.SERVER_VERSION] + } + + if (params.PYTHON_VERSION != "DEFAULT") { + matrix["RUNTIME"] = [params.PYTHON_VERSION] + } + + if (params.CI_SCHEDULE == "WEEKNIGHTS") { + matrix["SERVER"] = params.CI_SCHEDULE_SERVER_VERSION.split(' ') + matrix["RUNTIME"] = params.CI_SCHEDULE_PYTHON_VERSION.split(' ') + } + + context = [ + vars: [ + "PROFILE=${PROFILE}", + "EVENT_LOOP=${EVENT_LOOP}" + ], + matrix: matrix + ] + + return context +} + +def buildAndTest(context) { + initializeEnvironment() + installDriverAndCompileExtensions() + + try { + executeTests() + } finally { + junit testResults: '*_results.xml' + } +} + +def getMatrixBuilds(buildContext) { + def tasks = [:] + matrix = buildContext.matrix + + matrix["SERVER"].each { serverVersion -> + matrix["RUNTIME"].each { runtimeVersion -> + matrix["CYTHON"].each { cythonFlag -> + def taskVars = [ + "CASSANDRA_VERSION=${serverVersion}", + "PYTHON_VERSION=${runtimeVersion}", + "CYTHON_ENABLED=${cythonFlag}" + ] + def cythonDesc = cythonFlag == "True" ? 
", Cython": "" + tasks["${serverVersion}, py${runtimeVersion}${cythonDesc}"] = { + node("${OS_VERSION}") { + scm_variables = checkout scm + env.GIT_COMMIT = scm_variables.get('GIT_COMMIT') + env.GIT_URL = scm_variables.get('GIT_URL') + initializeSlackContext() + + if (env.BUILD_STATED_SLACK_NOTIFIED != 'true') { + slack.notifyChannel() + } + + withEnv(taskVars) { + buildAndTest(context) + } + } + } + } + } + } + return tasks +} + +def initializeEnvironment() { + sh label: 'Initialize the environment', script: '''#!/bin/bash -lex + pyenv global ${PYTHON_VERSION} + sudo apt-get install socat + pip install --upgrade pip + pip install -U setuptools + + # install a version of pyyaml<6.0 compatible with ccm-3.1.5 as of Aug 2023 + # this works around the python-3.10+ compatibility problem as described in DSP-23524 + pip install wheel + pip install "Cython<3.0" "pyyaml<6.0" --no-build-isolation + pip install ${HOME}/ccm + ''' + + // Determine if server version is Apache CassandraⓇ or DataStax Enterprise + if (env.CASSANDRA_VERSION.split('-')[0] == 'dse') { + if (env.PYTHON_VERSION =~ /3\.12\.\d+/) { + echo "Cannot install DSE dependencies for Python 3.12.x; installing Apache CassandraⓇ requirements only. See PYTHON-1368 for more detail." + sh label: 'Install Apache CassandraⓇ requirements', script: '''#!/bin/bash -lex + pip install -r test-requirements.txt + ''' + } + else { + sh label: 'Install DataStax Enterprise requirements', script: '''#!/bin/bash -lex + pip install -r test-datastax-requirements.txt + ''' + } + } else { + sh label: 'Install Apache CassandraⓇ requirements', script: '''#!/bin/bash -lex + pip install -r test-requirements.txt + ''' + + sh label: 'Uninstall the geomet dependency since it is not required for Cassandra', script: '''#!/bin/bash -lex + pip uninstall -y geomet + ''' + } + + sh label: 'Install unit test modules', script: '''#!/bin/bash -lex + pip install --no-deps nose-ignore-docstring nose-exclude + pip install service_identity + ''' + + if (env.CYTHON_ENABLED == 'True') { + sh label: 'Install cython modules', script: '''#!/bin/bash -lex + pip install cython numpy + ''' + } + + sh label: 'Download Apache CassandraⓇ or DataStax Enterprise', script: '''#!/bin/bash -lex + . ${CCM_ENVIRONMENT_SHELL} ${CASSANDRA_VERSION} + ''' + + if (env.CASSANDRA_VERSION.split('-')[0] == 'dse') { + env.DSE_FIXED_VERSION = env.CASSANDRA_VERSION.split('-')[1] + sh label: 'Update environment for DataStax Enterprise', script: '''#!/bin/bash -le + cat >> ${HOME}/environment.txt << ENVIRONMENT_EOF +CCM_CASSANDRA_VERSION=${DSE_FIXED_VERSION} # maintain for backwards compatibility +CCM_VERSION=${DSE_FIXED_VERSION} +CCM_SERVER_TYPE=dse +DSE_VERSION=${DSE_FIXED_VERSION} +CCM_IS_DSE=true +CCM_BRANCH=${DSE_FIXED_VERSION} +DSE_BRANCH=${DSE_FIXED_VERSION} +ENVIRONMENT_EOF + ''' + } else if (env.CASSANDRA_VERSION.split('-')[0] == 'hcd') { + env.HCD_FIXED_VERSION = env.CASSANDRA_VERSION.split('-')[1] + sh label: 'Update environment for DataStax Enterprise', script: '''#!/bin/bash -le + cat >> ${HOME}/environment.txt << ENVIRONMENT_EOF +CCM_CASSANDRA_VERSION=${HCD_FIXED_VERSION} # maintain for backwards compatibility +CCM_VERSION=${HCD_FIXED_VERSION} +CCM_SERVER_TYPE=hcd +HCD_VERSION=${HCD_FIXED_VERSION} +CCM_IS_HCD=true +CCM_BRANCH=${HCD_FIXED_VERSION} +HCD_BRANCH=${HCD_FIXED_VERSION} +ENVIRONMENT_EOF + ''' + } + + sh label: 'Display Python and environment information', script: '''#!/bin/bash -le + # Load CCM environment variables + set -o allexport + . 
${HOME}/environment.txt + set +o allexport + + python --version + pip --version + pip freeze + printenv | sort + ''' +} + +def installDriverAndCompileExtensions() { + if (env.CYTHON_ENABLED == 'True') { + sh label: 'Install the driver and compile with C extensions with Cython', script: '''#!/bin/bash -lex + python setup.py build_ext --inplace + ''' + } else { + sh label: 'Install the driver and compile with C extensions without Cython', script: '''#!/bin/bash -lex + python setup.py build_ext --inplace --no-cython + ''' + } +} + +def executeStandardTests() { + + try { + sh label: 'Execute unit tests', script: '''#!/bin/bash -lex + # Load CCM environment variables + set -o allexport + . ${HOME}/environment.txt + set +o allexport + + failure=0 + EVENT_LOOP=${EVENT_LOOP} VERIFY_CYTHON=${CYTHON_ENABLED} JVM_EXTRA_OPTS="$JVM_EXTRA_OPTS -Xss384k" pytest -s -v --log-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --junit-xml=unit_results.xml tests/unit/ || failure=1 + EVENT_LOOP_MANAGER=eventlet VERIFY_CYTHON=${CYTHON_ENABLED} JVM_EXTRA_OPTS="$JVM_EXTRA_OPTS -Xss384k" pytest -s -v --log-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --junit-xml=unit_eventlet_results.xml tests/unit/io/test_eventletreactor.py || failure=1 + EVENT_LOOP_MANAGER=gevent VERIFY_CYTHON=${CYTHON_ENABLED} JVM_EXTRA_OPTS="$JVM_EXTRA_OPTS -Xss384k" pytest -s -v --log-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --junit-xml=unit_gevent_results.xml tests/unit/io/test_geventreactor.py || failure=1 + exit $failure + ''' + } catch (err) { + currentBuild.result = 'UNSTABLE' + } + + try { + sh label: 'Execute Simulacron integration tests', script: '''#!/bin/bash -lex + # Load CCM environment variables + set -o allexport + . ${HOME}/environment.txt + set +o allexport + + . 
${JABBA_SHELL} + jabba use 1.8 + + failure=0 + SIMULACRON_JAR="${HOME}/simulacron.jar" + SIMULACRON_JAR=${SIMULACRON_JAR} EVENT_LOOP=${EVENT_LOOP} CASSANDRA_DIR=${CCM_INSTALL_DIR} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} HCD_VERSION=${HCD_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} JVM_EXTRA_OPTS="$JVM_EXTRA_OPTS -Xss384k" pytest -s -v --log-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --ignore=test_backpressure.py --junit-xml=simulacron_results.xml tests/integration/simulacron/ || true + + # Run backpressure tests separately to avoid memory issue + SIMULACRON_JAR=${SIMULACRON_JAR} EVENT_LOOP=${EVENT_LOOP} CASSANDRA_DIR=${CCM_INSTALL_DIR} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} HCD_VERSION=${HCD_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} JVM_EXTRA_OPTS="$JVM_EXTRA_OPTS -Xss384k" pytest -s -v --log-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --ignore=test_backpressure.py --junit-xml=simulacron_backpressure_1_results.xml tests/integration/simulacron/test_backpressure.py:TCPBackpressureTests.test_paused_connections || failure=1 + SIMULACRON_JAR=${SIMULACRON_JAR} EVENT_LOOP=${EVENT_LOOP} CASSANDRA_DIR=${CCM_INSTALL_DIR} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} HCD_VERSION=${HCD_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} JVM_EXTRA_OPTS="$JVM_EXTRA_OPTS -Xss384k" pytest -s -v --log-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --ignore=test_backpressure.py --junit-xml=simulacron_backpressure_2_results.xml tests/integration/simulacron/test_backpressure.py:TCPBackpressureTests.test_queued_requests_timeout || failure=1 + SIMULACRON_JAR=${SIMULACRON_JAR} EVENT_LOOP=${EVENT_LOOP} CASSANDRA_DIR=${CCM_INSTALL_DIR} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} HCD_VERSION=${HCD_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} JVM_EXTRA_OPTS="$JVM_EXTRA_OPTS -Xss384k" pytest -s -v --log-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --ignore=test_backpressure.py --junit-xml=simulacron_backpressure_3_results.xml tests/integration/simulacron/test_backpressure.py:TCPBackpressureTests.test_cluster_busy || failure=1 + SIMULACRON_JAR=${SIMULACRON_JAR} EVENT_LOOP=${EVENT_LOOP} CASSANDRA_DIR=${CCM_INSTALL_DIR} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} HCD_VERSION=${HCD_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} JVM_EXTRA_OPTS="$JVM_EXTRA_OPTS -Xss384k" pytest -s -v --log-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --ignore=test_backpressure.py --junit-xml=simulacron_backpressure_4_results.xml tests/integration/simulacron/test_backpressure.py:TCPBackpressureTests.test_node_busy || failure=1 + exit $failure + ''' + } catch (err) { + currentBuild.result = 'UNSTABLE' + } + + try { + sh label: 'Execute CQL engine integration tests', script: '''#!/bin/bash -lex + # Load CCM environment variables + set -o allexport + . ${HOME}/environment.txt + set +o allexport + + . 
${JABBA_SHELL} + jabba use 1.8 + + EVENT_LOOP=${EVENT_LOOP} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} HCD_VERSION=${HCD_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} JVM_EXTRA_OPTS="$JVM_EXTRA_OPTS -Xss384k" pytest -s -v --log-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --junit-xml=cqle_results.xml tests/integration/cqlengine/ + ''' + } catch (err) { + currentBuild.result = 'UNSTABLE' + } + + try { + sh label: 'Execute Apache CassandraⓇ integration tests', script: '''#!/bin/bash -lex + # Load CCM environment variables + set -o allexport + . ${HOME}/environment.txt + set +o allexport + + . ${JABBA_SHELL} + jabba use 1.8 + + EVENT_LOOP=${EVENT_LOOP} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} HCD_VERSION=${HCD_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} JVM_EXTRA_OPTS="$JVM_EXTRA_OPTS -Xss384k" pytest -s -v --log-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --junit-xml=standard_results.xml tests/integration/standard/ + ''' + } catch (err) { + currentBuild.result = 'UNSTABLE' + } + + if (env.CASSANDRA_VERSION.split('-')[0] == 'dse' && env.CASSANDRA_VERSION.split('-')[1] != '4.8') { + if (env.PYTHON_VERSION =~ /3\.12\.\d+/) { + echo "Cannot install DSE dependencies for Python 3.12.x. See PYTHON-1368 for more detail." + } + else { + try { + sh label: 'Execute DataStax Enterprise integration tests', script: '''#!/bin/bash -lex + # Load CCM environment variable + set -o allexport + . ${HOME}/environment.txt + set +o allexport + + . ${JABBA_SHELL} + jabba use 1.8 + + EVENT_LOOP=${EVENT_LOOP} CASSANDRA_DIR=${CCM_INSTALL_DIR} DSE_VERSION=${DSE_VERSION} HCD_VERSION=${HCD_VERSION} ADS_HOME="${HOME}/" VERIFY_CYTHON=${CYTHON_ENABLED} JVM_EXTRA_OPTS="$JVM_EXTRA_OPTS -Xss384k" pytest -s -v --log-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --junit-xml=dse_results.xml tests/integration/advanced/ + ''' + } catch (err) { + currentBuild.result = 'UNSTABLE' + } + } + } + + try { + sh label: 'Execute DataStax Astra integration tests', script: '''#!/bin/bash -lex + # Load CCM environment variable + set -o allexport + . ${HOME}/environment.txt + set +o allexport + + . ${JABBA_SHELL} + jabba use 1.8 + + EVENT_LOOP=${EVENT_LOOP} CLOUD_PROXY_PATH="${HOME}/proxy/" CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} JVM_EXTRA_OPTS="$JVM_EXTRA_OPTS -Xss384k" pytest -s -v --log-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --junit-xml=advanced_results.xml tests/integration/cloud/ + ''' + } catch (err) { + currentBuild.result = 'UNSTABLE' + } + + if (env.PROFILE == 'FULL') { + try { + sh label: 'Execute long running integration tests', script: '''#!/bin/bash -lex + # Load CCM environment variable + set -o allexport + . ${HOME}/environment.txt + set +o allexport + + . 
${JABBA_SHELL} + jabba use 1.8 + + EVENT_LOOP=${EVENT_LOOP} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} HCD_VERSION=${HCD_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} JVM_EXTRA_OPTS="$JVM_EXTRA_OPTS -Xss384k" pytest -s -v --log-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --ignore=tests/integration/long/upgrade --junit-xml=long_results.xml tests/integration/long/ + ''' + } catch (err) { + currentBuild.result = 'UNSTABLE' + } + } +} + +def executeDseSmokeTests() { + sh label: 'Execute profile DataStax Enterprise smoke test integration tests', script: '''#!/bin/bash -lex + # Load CCM environment variable + set -o allexport + . ${HOME}/environment.txt + set +o allexport + + . ${JABBA_SHELL} + jabba use 1.8 + + EVENT_LOOP=${EVENT_LOOP} CCM_ARGS="${CCM_ARGS}" CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} DSE_VERSION=${DSE_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} JVM_EXTRA_OPTS="$JVM_EXTRA_OPTS -Xss384k" pytest -s -v --log-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --junit-xml=standard_results.xml tests/integration/standard/test_dse.py + ''' +} + +def executeEventLoopTests() { + sh label: 'Execute profile event loop manager integration tests', script: '''#!/bin/bash -lex + # Load CCM environment variable + set -o allexport + . ${HOME}/environment.txt + set +o allexport + + . ${JABBA_SHELL} + jabba use 1.8 + + EVENT_LOOP_TESTS=( + "tests/integration/standard/test_cluster.py" + "tests/integration/standard/test_concurrent.py" + "tests/integration/standard/test_connection.py" + "tests/integration/standard/test_control_connection.py" + "tests/integration/standard/test_metrics.py" + "tests/integration/standard/test_query.py" + "tests/integration/simulacron/test_endpoint.py" + "tests/integration/long/test_ssl.py" + ) + EVENT_LOOP=${EVENT_LOOP} CCM_ARGS="${CCM_ARGS}" DSE_VERSION=${DSE_VERSION} HCD_VERSION=${HCD_VERSION} CASSANDRA_VERSION=${CCM_CASSANDRA_VERSION} MAPPED_CASSANDRA_VERSION=${MAPPED_CASSANDRA_VERSION} VERIFY_CYTHON=${CYTHON_ENABLED} JVM_EXTRA_OPTS="$JVM_EXTRA_OPTS -Xss384k" pytest -s -v --log-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --junit-xml=standard_results.xml ${EVENT_LOOP_TESTS[@]} + ''' +} + +def executeTests() { + switch(env.PROFILE) { + case 'DSE-SMOKE-TEST': + executeDseSmokeTests() + break + case 'EVENT_LOOP': + executeEventLoopTests() + break + default: + executeStandardTests() + break + } +} + + +// TODO move this in the shared lib +def getDriverMetricType() { + metric_type = 'oss' + if (env.GIT_URL.contains('riptano/python-driver')) { + metric_type = 'oss-private' + } else if (env.GIT_URL.contains('python-dse-driver')) { + metric_type = 'dse' + } + return metric_type +} + +def describeBuild(buildContext) { + script { + def runtimes = buildContext.matrix["RUNTIME"] + def serverVersions = buildContext.matrix["SERVER"] + def numBuilds = runtimes.size() * serverVersions.size() * buildContext.matrix["CYTHON"].size() + currentBuild.displayName = "${env.PROFILE} (${env.EVENT_LOOP} | ${numBuilds} builds)" + currentBuild.description = "${env.PROFILE} build testing servers (${serverVersions.join(', ')}) against Python (${runtimes.join(', ')}) using ${env.EVENT_LOOP} event loop manager" + } +} + +// branch pattern for cron +def branchPatternCron() { + ~"(master)" +} + +pipeline { + agent none + + // Global pipeline timeout + options { + disableConcurrentBuilds() + timeout(time: 10, unit: 
'HOURS') // TODO timeout should be per build + buildDiscarder(logRotator(artifactNumToKeepStr: '10', // Keep only the last 10 artifacts + numToKeepStr: '50')) // Keep only the last 50 build records + } + + parameters { + choice( + name: 'ADHOC_BUILD_TYPE', + choices: ['BUILD', 'BUILD-AND-EXECUTE-TESTS'], + description: '''

Perform an adhoc build operation

Choice                     Description
BUILD                      Performs a Per-Commit build
BUILD-AND-EXECUTE-TESTS    Performs a build and executes the integration and unit tests
''') + choice( + name: 'PROFILE', + choices: ['STANDARD', 'FULL', 'DSE-SMOKE-TEST', 'EVENT_LOOP'], + description: '''

Profile to utilize for scheduled or adhoc builds

Choice            Description
STANDARD          Execute the standard tests for the driver
FULL              Execute all tests for the driver, including long tests.
DSE-SMOKE-TEST    Execute only the DataStax Enterprise smoke tests
EVENT_LOOP        Execute only the event loop tests for the specified event loop manager (see: EVENT_LOOP)
''') + choice( + name: 'MATRIX', + choices: ['DEFAULT', 'SMOKE', 'FULL', 'CASSANDRA', 'DSE', 'HCD'], + description: '''

The matrix for the build.

Choice       Description
DEFAULT      Default to the build context.
SMOKE        Basic smoke tests for current Python runtimes + C*/DSE versions, no Cython
FULL         All server versions, python runtimes tested with and without Cython.
CASSANDRA    All cassandra server versions.
DSE          All dse server versions.
HCD          All hcd server versions.
''') + choice( + name: 'PYTHON_VERSION', + choices: ['DEFAULT'] + DEFAULT_RUNTIME, + description: 'Python runtime version. Default to the build context.') + choice( + name: 'SERVER_VERSION', + choices: ['DEFAULT'] + DEFAULT_CASSANDRA + DEFAULT_DSE + DEFAULT_HCD, + description: '''Apache CassandraⓇ and DataStax Enterprise server version to use for adhoc BUILD-AND-EXECUTE-TESTS ONLY!
Choice        Description
DEFAULT       Default to the build context.
3.0           Apache CassandraⓇ v3.0.x
3.11          Apache CassandraⓇ v3.11.x
4.0           Apache CassandraⓇ v4.0.x
5.0           Apache CassandraⓇ v5.0.x
dse-5.1.35    DataStax Enterprise v5.1.x
dse-6.8.30    DataStax Enterprise v6.8.x
dse-6.9.0     DataStax Enterprise v6.9.x (CURRENTLY UNDER DEVELOPMENT)
hcd-1.0.0     DataStax HCD v1.0.x (CURRENTLY UNDER DEVELOPMENT)
''') + choice( + name: 'CYTHON', + choices: ['DEFAULT'] + DEFAULT_CYTHON, + description: '''

Flag to determine if Cython should be enabled

Choice     Description
Default    Default to the build context.
True       Enable Cython
False      Disable Cython
''') + choice( + name: 'EVENT_LOOP', + choices: ['LIBEV', 'GEVENT', 'EVENTLET', 'ASYNCIO', 'ASYNCORE', 'TWISTED'], + description: '''

Event loop manager to utilize for scheduled or adhoc builds

Choice      Description
LIBEV       A full-featured and high-performance event loop that is loosely modeled after libevent, but without its limitations and bugs
GEVENT      A coroutine-based Python networking library that uses greenlet to provide a high-level synchronous API on top of the libev or libuv event loop
EVENTLET    A concurrent networking library for Python that allows you to change how you run your code, not how you write it
ASYNCIO     A library to write concurrent code using the async/await syntax
ASYNCORE    A module that provides the basic infrastructure for writing asynchronous socket service clients and servers
TWISTED     An event-driven networking engine written in Python and licensed under the open source MIT license
''') + choice( + name: 'CI_SCHEDULE', + choices: ['DO-NOT-CHANGE-THIS-SELECTION', 'WEEKNIGHTS', 'WEEKENDS'], + description: 'CI testing schedule to execute periodically scheduled builds and tests of the driver (DO NOT CHANGE THIS SELECTION)') + string( + name: 'CI_SCHEDULE_PYTHON_VERSION', + defaultValue: 'DO-NOT-CHANGE-THIS-SELECTION', + description: 'CI testing python version to utilize for scheduled test runs of the driver (DO NOT CHANGE THIS SELECTION)') + string( + name: 'CI_SCHEDULE_SERVER_VERSION', + defaultValue: 'DO-NOT-CHANGE-THIS-SELECTION', + description: 'CI testing server version to utilize for scheduled test runs of the driver (DO NOT CHANGE THIS SELECTION)') + } + + triggers { + parameterizedCron(branchPatternCron().matcher(env.BRANCH_NAME).matches() ? """ + # Every weeknight (Monday - Friday) around 4:00 AM + # These schedules will run with and without Cython enabled for Python 3.9.23 and 3.13.5 + H 4 * * 1-5 %CI_SCHEDULE=WEEKNIGHTS;EVENT_LOOP=LIBEV;CI_SCHEDULE_PYTHON_VERSION=3.9.23 3.13.5;CI_SCHEDULE_SERVER_VERSION=3.11 4.0 5.0 dse-5.1.35 dse-6.8.30 dse-6.9.0 hcd-1.0.0 + """ : "") + } + + environment { + OS_VERSION = 'ubuntu/focal64/python-driver' + CCM_ENVIRONMENT_SHELL = '/usr/local/bin/ccm_environment.sh' + CCM_MAX_HEAP_SIZE = '1536M' + JABBA_SHELL = '/usr/lib/jabba/jabba.sh' + } + + stages { + stage ('Build and Test') { + when { + beforeAgent true + allOf { + not { buildingTag() } + } + } + + steps { + script { + context = getBuildContext() + withEnv(context.vars) { + describeBuild(context) + + // build and test all builds + parallel getMatrixBuilds(context) + + slack.notifyChannel(currentBuild.currentResult) + } + } + } + } + + } +} diff --git a/NOTICE b/NOTICE new file mode 100644 index 0000000000..58250f616b --- /dev/null +++ b/NOTICE @@ -0,0 +1,115 @@ +Apache Cassandra Python Driver +Copyright 2013 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + + +This product originates, before git sha +5d4fd2349119a3237ad351a96e7f2b3317159305, from software from DataStax and other +individual contributors. All work was previously copyrighted to DataStax. + +Non-DataStax contributors are listed below. Those marked with asterisk have +explicitly consented to their contributions being donated to the ASF. 
+ +a-detiste Alexandre Detiste * +a-lst Andrey Istochkin * +aboudreault Alan Boudreault alan@alanb.ca +advance512 Alon Diamant diamant.alon@gmail.com * +alanjds Alan Justino da Silva alan.justino@yahoo.com.br * +alistair-broomhead Alistair Broomhead * +amygdalama Amy Hanlon * +andy-slac Andy Salnikov * +andy8zhao Andy Zhao +anthony-cervantes Anthony Cervantes anthony@cervantes.io * +BackEndTea Gert de Pagter * +barvinograd Bar Vinograd +bbirand Berk Birand +bergundy Roey Berman roey.berman@gmail.com * +bohdantan * +codesnik Alexey Trofimenko aronaxis@gmail.com * +coldeasy coldeasy +DanieleSalatti Daniele Salatti me@danielesalatti.com * +daniloarodrigues Danilo de Araújo Rodrigues * +daubman Aaron Daubman github@ajd.us * +dcosson Danny Cosson dcosson@gmail.com * +detzgk Eli Green eli@zigr.org * +devdazed Russ Bradberry * +dizpers Dmitry Belaventsev dizpers@gmail.com +dkropachev Dmitry Kropachev dmitry.kropachev@gmail.com * +dmglab Daniel dmg.lab@outlook.com +dokai Kai Lautaportti * +eamanu Emmanuel Arias eamanu@yaerobi.com +figpope Andrew FigPope andrew.figpope@gmail.com * +flupke Luper Rouch * +frensjan Frens Jan Rumph * +frew Fred Wulff frew@cs.stanford.edu * +gdoermann Greg Doermann +haaawk Piotr Jastrzębski +ikapl Irina Kaplounova +ittus Thang Minh Vu * +JeremyOT Jeremy Olmsted-Thompson * +jeremyschlatter Jeremy Schlatter * +jpuerta Ernesto Puerta * +julien-duponchelle Julien Duponchelle julien@duponchelle.info * +justinsb Justin Santa Barbara justinsb@google.com * +Kami Tomaz Muraus tomaz@tomaz.me +kandul Michał Kandulski michal.kandulski@gmail.com +kdeldycke Kevin Deldycke * +kishkaru Kishan Karunaratne kishan@karu.io * +kracekumar Kracekumar kracethekingmaker@gmail.com +lenards Andrew Lenards andrew.lenards@gmail.com * +lenolib +Lifto Ellis Low +Lorak-mmk Karol Baryła git@baryla.org * +lukaselmer Lukas Elmer lukas.elmer@gmail.com * +mahall Michael Hall +markflorisson Mark Florisson * +mattrobenolt Matt Robenolt m@robenolt.com * +mattstibbs Matt Stibbs * +Mhs-220 Mo Shahmohammadi hos1377@gmail.com * +mikeokner Mike Okner * +Mishail Mikhail Stepura mstepura@apple.com * +mission-liao mission.liao missionaryliao@gmail.com * +mkocikowski Mik Kocikowski +Mokto Théo Mathieu * +mrk-its Mariusz Kryński * +multani Jonathan Ballet jon@multani.info * +niklaskorz Niklas Korz * +nisanharamati nisanharamati +nschrader Nick Schrader nick.schrader@mailbox.org * +Orenef11 Oren Efraimov * +oz123 Oz Tiram * +pistolero Sergii Kyryllov * +pmcnett Paul McNett p@ulmcnett.com * +psarna Piotr Sarna * +r4fek Rafał Furmański * +raopm +rbranson Rick Branson * +rqx Roman Khanenko * +rtb-zla-karma xyz * +sigmunau +silviot Silvio Tomatis +sontek John Anderson sontek@gmail.com * +stanhu Stan Hu +stefanor Stefano Rivera stefanor@debian.org * +strixcuriosus Ash Hoover strixcuriosus@gmail.com +tarzanjw Học Đỗ hoc3010@gmail.com +tbarbugli Tommaso Barbugli +tchaikov Kefu Chai tchaikov@gmail.com * +tglines Travis Glines +thoslin Tom Lin +tigrus Nikolay Fominykh nikolayfn@gmail.com +timgates42 Tim Gates +timsavage Tim Savage * +tirkarthi Karthikeyan Singaravelan tir.karthi@gmail.com * +Trundle Andreas Stührk andy@hammerhartes.de +ubombi Vitalii Kozlovskyi vitalii@kozlovskyi.dev * +ultrabug Ultrabug * +vetal4444 Shevchenko Vitaliy * +victorpoluceno Victor Godoy Poluceno victorpoluceno@gmail.com +weisslj Johannes Weißl * +wenheping wenheping wenheping2000@hotmail.com +yi719 +yinyin Yinyin * +yriveiro Yago Riveiro * \ No newline at end of file diff --git a/README-dev.rst b/README-dev.rst index deb2666fb7..939d3fa480 100644 
--- a/README-dev.rst +++ b/README-dev.rst @@ -1,20 +1,37 @@ Releasing ========= +Note: the precise details of some of these steps have changed. Leaving this here as a guide only. + * Run the tests and ensure they all pass * Update CHANGELOG.rst + * Check for any missing entries + * Add today's date to the release section * Update the version in ``cassandra/__init__.py`` - * For beta releases, use a version like ``(2, 1, '0b1')`` * For release candidates, use a version like ``(2, 1, '0rc1')`` * When in doubt, follow PEP 440 versioning - -* Commit the changelog and version changes +* Add the new version in ``docs.yaml`` +* Commit the changelog and version changes, e.g. ``git commit -m'version 1.0.0'`` * Tag the release. For example: ``git tag -a 1.0.0 -m 'version 1.0.0'`` -* Push the commit and tag: ``git push --tags origin master`` -* Upload the package to pypi:: +* Push the tag and new ``master``: ``git push origin 1.0.0 ; git push origin master`` +* Update the `python-driver` submodule of `python-driver-wheels`, + commit then push. +* Trigger the Github Actions necessary to build wheels for the various platforms +* For a GA release, upload the package to pypi:: + + # Clean the working directory + python setup.py clean + rm dist/* + + # Build the source distribution + python setup.py sdist - python setup.py register - python setup.py sdist upload + # Download all wheels from the jfrog repository and copy them in + # the dist/ directory + cp /path/to/wheels/*.whl dist/ + + # Upload all files + twine upload dist/* * On pypi, make the latest GA the only visible version * Update the docs (see below) @@ -23,108 +40,74 @@ Releasing * After a beta or rc release, this should look like ``(2, 1, '0b1', 'post0')`` +* After the release has been tagged, add a section to docs.yaml with the new tag ref:: + + versions: + - name: + ref: + * Commit and push * Update 'cassandra-test' branch to reflect new release - + * this is typically a matter of merging or rebasing onto master * test and push updated branch to origin -* Update the JIRA versions: https://datastax-oss.atlassian.net/plugins/servlet/project-config/PYTHON/versions -* Make an announcement on the mailing list - -Building the Docs -================= -Sphinx is required to build the docs. You probably want to install through apt, -if possible:: - - sudo apt-get install python-sphinx - -pip may also work:: - - sudo pip install -U Sphinx - -To build the docs, run:: - - python setup.py doc - -To upload the docs, checkout the ``gh-pages`` branch (it's usually easier to -clone a second copy of this repo and leave it on that branch) and copy the entire -contents all of ``docs/_build/X.Y.Z/*`` into the root of the ``gh-pages`` branch -and then push that branch to github. - -For example:: - - python setup.py doc - cp -R docs/_build/1.0.0-beta1/* ~/python-driver-docs/ - cd ~/python-driver-docs - git add --all - git commit -m 'Update docs' - git push origin gh-pages - -If docs build includes errors, those errors may not show up in the next build unless -you have changed the files with errors. 
It's good to occassionally clear the build -directory and build from scratch:: +* Update the JIRA releases: https://issues.apache.org/jira/projects/CASSPYTHON?selectedItem=com.atlassian.jira.jira-projects-plugin:release-page - rm -rf docs/_build/* + * add release dates and set version as "released" -Running the Tests -================= -In order for the extensions to be built and used in the test, run:: +* Make an announcement on the mailing list - python setup.py nosetests +Tests +===== -You can run a specific test module or package like so:: +Running Unit Tests +------------------ +Unit tests can be run like so:: - python setup.py nosetests -w tests/unit/ + pytest tests/unit/ You can run a specific test method like so:: - python setup.py nosetests -w tests/unit/test_connection.py:ConnectionTest.test_bad_protocol_version - -Seeing Test Logs in Real Time ------------------------------ -Sometimes it's useful to output logs for the tests as they run:: - - python setup.py nosetests -w tests/unit/ --nocapture --nologcapture + pytest tests/unit/test_connection.py::ConnectionTest::test_bad_protocol_version -Use tee to capture logs and see them on your terminal:: +Running Integration Tests +------------------------- +In order to run integration tests, you must specify a version to run using the ``CASSANDRA_VERSION`` or ``DSE_VERSION`` environment variable:: - python setup.py nosetests -w tests/unit/ --nocapture --nologcapture 2>&1 | tee test.log + CASSANDRA_VERSION=2.0.9 pytest tests/integration/standard -Specifying a Cassandra Version for Integration Tests ----------------------------------------------------- -You can specify a cassandra version with the ``CASSANDRA_VERSION`` environment variable:: +Or you can specify a cassandra directory (to test unreleased versions):: - CASSANDRA_VERSION=2.0.9 python setup.py nosetests -w tests/integration/standard + CASSANDRA_DIR=/path/to/cassandra pytest tests/integration/standard/ -You can also specify a cassandra directory (to test unreleased versions):: +Specifying the usage of an already running Cassandra cluster +------------------------------------------------------------ +The test will start the appropriate Cassandra clusters when necessary but if you don't want this to happen because a Cassandra cluster is already running the flag ``USE_CASS_EXTERNAL`` can be used, for example:: - CASSANDRA_DIR=/home/thobbs/cassandra python setup.py nosetests -w tests/integration/standard + USE_CASS_EXTERNAL=1 CASSANDRA_VERSION=2.0.9 pytest tests/integration/standard Specify a Protocol Version for Tests ------------------------------------ The protocol version defaults to 1 for cassandra 1.2 and 2 otherwise. You can explicitly set it with the ``PROTOCOL_VERSION`` environment variable:: - PROTOCOL_VERSION=3 python setup.py nosetests -w tests/integration/standard + PROTOCOL_VERSION=3 pytest tests/integration/standard Testing Multiple Python Versions -------------------------------- -If you want to test all of python 2.6, 2.7, and pypy, use tox (this is what -TravisCI runs):: +Use tox to test all of Python 3.9 through 3.13 and pypy:: tox -By default, tox only runs the unit tests because I haven't put in the effort -to get the integration tests to run on TravicCI. However, the integration -tests should work locally. To run them, edit the following line in tox.ini:: - - commands = {envpython} setup.py build_ext --inplace nosetests --verbosity=2 tests/unit/ - -and change ``tests/unit/`` to ``tests/``. +By default, tox only runs the unit tests. 
Running the Benchmarks ====================== +There needs to be a version of cassandra running locally so before running the benchmarks, if ccm is installed: + + ccm create benchmark_cluster -v 3.0.1 -n 1 -s + To run the benchmarks, pick one of the files under the ``benchmarks/`` dir and run it:: python benchmarks/future_batches.py @@ -146,3 +129,27 @@ name to specify the built version:: python setup.py egg_info -b-`git rev-parse --short HEAD` sdist --formats=zip The file (``dist/cassandra-driver-.zip``) is packaged with Cassandra in ``cassandra/lib/cassandra-driver-internal-only*zip``. + +Releasing an EAP +================ + +An EAP release is only uploaded on a private server and it is not published on pypi. + +* Clean the environment:: + + python setup.py clean + +* Package the source distribution:: + + python setup.py sdist + +* Test the source distribution:: + + pip install dist/cassandra-driver-.tar.gz + +* Upload the package on the EAP download server. +* Build the documentation:: + + python setup.py doc + +* Upload the docs on the EAP download server. diff --git a/README.rst b/README.rst index 88d50dabdb..47b5593ee9 100644 --- a/README.rst +++ b/README.rst @@ -1,28 +1,39 @@ -DataStax Python Driver for Apache Cassandra -=========================================== -.. image:: https://travis-ci.org/datastax/python-driver.png?branch=master - :target: https://travis-ci.org/datastax/python-driver +.. |license| image:: https://img.shields.io/badge/License-Apache%202.0-blue.svg + :target: https://opensource.org/licenses/Apache-2.0 +.. |version| image:: https://badge.fury.io/py/cassandra-driver.svg + :target: https://badge.fury.io/py/cassandra-driver +.. |pyversion| image:: https://img.shields.io/pypi/pyversions/cassandra-driver.svg +.. |travis| image:: https://api.travis-ci.com/datastax/python-driver.svg?branch=master + :target: https://travis-ci.com/github/datastax/python-driver -A modern, `feature-rich `_ and highly-tunable Python client library for Apache Cassandra (1.2+) and DataStax Enterprise (3.1+) using exclusively Cassandra's binary protocol and Cassandra Query Language v3. +|license| |version| |pyversion| |travis| -The driver supports Python 2.6, 2.7, 3.3, and 3.4. +Apache Cassandra Python Driver +============================== -Feedback Requested ------------------- -**Help us focus our efforts!** Provide your input on the `Platform and Runtime Survey `_ (we kept it short). +A modern, `feature-rich `_ and highly-tunable Python client library for Apache Cassandra (2.1+) and +DataStax Enterprise (4.7+) using exclusively Cassandra's binary protocol and Cassandra Query Language v3. + +The driver supports Python 3.9 through 3.13. + +**Note:** DataStax products do not support big-endian systems. 
Features -------- -* `Synchronous `_ and `Asynchronous `_ APIs -* `Simple, Prepared, and Batch statements `_ +* `Synchronous `_ and `Asynchronous `_ APIs +* `Simple, Prepared, and Batch statements `_ * Asynchronous IO, parallel execution, request pipelining -* `Connection pooling `_ +* `Connection pooling `_ * Automatic node discovery -* `Automatic reconnection `_ -* Configurable `load balancing `_ and `retry policies `_ -* `Concurrent execution utilities `_ -* `Object mapper `_ +* `Automatic reconnection `_ +* Configurable `load balancing `_ and `retry policies `_ +* `Concurrent execution utilities `_ +* `Object mapper `_ +* `Connecting to DataStax Astra database (cloud) `_ +* DSE Graph execution API +* DSE Geometric type serialization +* DSE PlainText and GSSAPI authentication Installation ------------ @@ -31,52 +42,45 @@ Installation through pip is recommended:: $ pip install cassandra-driver For more complete installation instructions, see the -`installation guide `_. +`installation guide `_. Documentation ------------- -The documentation can be found online `here `_. +The documentation can be found online `here `_. A couple of links for getting up to speed: -* `Installation `_ -* `Getting started guide `_ -* `API docs `_ -* `Performance tips `_ +* `Installation `_ +* `Getting started guide `_ +* `API docs `_ +* `Performance tips `_ Object Mapper ------------- cqlengine (originally developed by Blake Eggleston and Jon Haddad, with contributions from the community) is now maintained as an integral part of this package. Refer to -`documentation here `_. +`documentation here `_. Contributing ------------ -See `CONTRIBUTING.md `_. +See `CONTRIBUTING.rst `_. Reporting Problems ------------------ Please report any bugs and make any feature requests on the -`JIRA `_ issue tracker. +`CASSPYTHON project `_ +of the ASF JIRA. If you would like to contribute, please feel free to open a pull request. Getting Help ------------ -Your two best options for getting help with the driver are the -`mailing list `_ -and the IRC channel. - -For IRC, use the #datastax-drivers channel on irc.freenode.net. If you don't have an IRC client, -you can use `freenode's web-based client `_. - -Features to be Added --------------------- -* C extension for encoding/decoding messages +You can talk about the driver, ask questions and get help in the #cassandra-drivers channel on +`ASF Slack `_. License ------- -Copyright 2013-2016 DataStax +Copyright 2013 The Apache Software Foundation Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
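The README changes above describe the synchronous and asynchronous APIs without showing them. As an illustrative sketch only (not part of the patch itself), the basic connect-and-query flow looks like the following; the contact point is a placeholder and the query against ``system.local`` is just an example::

    from cassandra.cluster import Cluster

    # Contact point is a placeholder; point it at your own cluster.
    cluster = Cluster(['127.0.0.1'])
    session = cluster.connect()

    # Synchronous execution of a simple statement.
    row = session.execute("SELECT release_version FROM system.local").one()
    print(row.release_version)

    # Asynchronous execution returns a future; result() blocks until it resolves.
    future = session.execute_async("SELECT release_version FROM system.local")
    print(future.result().one().release_version)

    cluster.shutdown()

Prepared and batch statements and the cqlengine object mapper listed under Features are used through the same ``Session`` object.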
diff --git a/appveyor.yml b/appveyor.yml new file mode 100644 index 0000000000..12c43d57a0 --- /dev/null +++ b/appveyor.yml @@ -0,0 +1,23 @@ +environment: + matrix: + - PYTHON: "C:\\Python38-x64" + cassandra_version: 3.11.2 + ci_type: standard +os: Visual Studio 2015 +platform: + - x64 +install: + - "SET PATH=%PYTHON%;%PYTHON%\\Scripts;%PATH%" + - ps: .\appveyor\appveyor.ps1 +build_script: + - cmd: | + "%VS140COMNTOOLS%\..\..\VC\vcvarsall.bat" x86_amd64 + python setup.py install --no-cython +test_script: + - ps: .\appveyor\run_test.ps1 +cache: + - C:\Users\appveyor\.m2 + - C:\ProgramData\chocolatey\bin + - C:\ProgramData\chocolatey\lib + - C:\Users\appveyor\jce_policy-1.7.0.zip + - C:\Users\appveyor\jce_policy-1.8.0.zip \ No newline at end of file diff --git a/appveyor/appveyor.ps1 b/appveyor/appveyor.ps1 new file mode 100644 index 0000000000..5f6840e4e1 --- /dev/null +++ b/appveyor/appveyor.ps1 @@ -0,0 +1,80 @@ +$env:JAVA_HOME="C:\Program Files\Java\jdk1.8.0" +$env:PATH="$($env:JAVA_HOME)\bin;$($env:PATH)" +$env:CCM_PATH="C:\Users\appveyor\ccm" +$env:CASSANDRA_VERSION=$env:cassandra_version +$env:EVENT_LOOP_MANAGER="asyncore" +$env:SIMULACRON_JAR="C:\Users\appveyor\simulacron-standalone-0.7.0.jar" + +python --version +python -c "import platform; print(platform.architecture())" +# Install Ant +Start-Process cinst -ArgumentList @("-y","ant") -Wait -NoNewWindow +# Workaround for ccm, link ant.exe -> ant.bat +If (!(Test-Path C:\ProgramData\chocolatey\bin\ant.bat)) { + cmd /c mklink C:\ProgramData\chocolatey\bin\ant.bat C:\ProgramData\chocolatey\bin\ant.exe +} + + +$jce_indicator = "$target\README.txt" +# Install Java Cryptographic Extensions, needed for SSL. +If (!(Test-Path $jce_indicator)) { + $zip = "C:\Users\appveyor\jce_policy-$($env:java_version).zip" + $target = "$($env:JAVA_HOME)\jre\lib\security" + # If this file doesn't exist we know JCE hasn't been installed. + $url = "https://www.dropbox.com/s/po4308hlwulpvep/UnlimitedJCEPolicyJDK7.zip?dl=1" + $extract_folder = "UnlimitedJCEPolicy" + If ($env:java_version -eq "1.8.0") { + $url = "https://www.dropbox.com/s/al1e6e92cjdv7m7/jce_policy-8.zip?dl=1" + $extract_folder = "UnlimitedJCEPolicyJDK8" + } + # Download zip to staging area if it doesn't exist, we do this because + # we extract it to the directory based on the platform and we want to cache + # this file so it can apply to all platforms. + if(!(Test-Path $zip)) { + (new-object System.Net.WebClient).DownloadFile($url, $zip) + } + + Add-Type -AssemblyName System.IO.Compression.FileSystem + [System.IO.Compression.ZipFile]::ExtractToDirectory($zip, $target) + + $jcePolicyDir = "$target\$extract_folder" + Move-Item $jcePolicyDir\* $target\ -force + Remove-Item $jcePolicyDir +} + +# Download simulacron +$simulacron_url = "https://github.com/datastax/simulacron/releases/download/0.7.0/simulacron-standalone-0.7.0.jar" +$simulacron_jar = $env:SIMULACRON_JAR +if(!(Test-Path $simulacron_jar)) { + (new-object System.Net.WebClient).DownloadFile($simulacron_url, $simulacron_jar) +} + +# Install Python Dependencies for CCM. +Start-Process python -ArgumentList "-m pip install psutil pyYaml six numpy" -Wait -NoNewWindow + +# Clone ccm from git and use master. +If (!(Test-Path $env:CCM_PATH)) { + Start-Process git -ArgumentList "clone -b cassandra-test https://github.com/pcmanus/ccm.git $($env:CCM_PATH)" -Wait -NoNewWindow +} + + +# Copy ccm -> ccm.py so windows knows to run it. 
+If (!(Test-Path $env:CCM_PATH\ccm.py)) { + Copy-Item "$env:CCM_PATH\ccm" "$env:CCM_PATH\ccm.py" +} + +$env:PYTHONPATH="$($env:CCM_PATH);$($env:PYTHONPATH)" +$env:PATH="$($env:CCM_PATH);$($env:PATH)" + +# Predownload cassandra version for CCM if it isn't already downloaded. +# This is necessary because otherwise ccm fails +If (!(Test-Path C:\Users\appveyor\.ccm\repository\$env:cassandra_version)) { + Start-Process python -ArgumentList "$($env:CCM_PATH)\ccm.py create -v $($env:cassandra_version) -n 1 predownload" -Wait -NoNewWindow + echo "Checking status of download" + python $env:CCM_PATH\ccm.py status + Start-Process python -ArgumentList "$($env:CCM_PATH)\ccm.py remove predownload" -Wait -NoNewWindow + echo "Downloaded version $env:cassandra_version" +} + +Start-Process python -ArgumentList "-m pip install -r test-requirements.txt" -Wait -NoNewWindow +Start-Process python -ArgumentList "-m pip install nose-ignore-docstring" -Wait -NoNewWindow diff --git a/appveyor/run_test.ps1 b/appveyor/run_test.ps1 new file mode 100644 index 0000000000..9b8c23fd8b --- /dev/null +++ b/appveyor/run_test.ps1 @@ -0,0 +1,49 @@ +Set-ExecutionPolicy Unrestricted +Set-ExecutionPolicy -ExecutionPolicy Unrestricted -Scope Process -force +Set-ExecutionPolicy -ExecutionPolicy Unrestricted -Scope CurrentUser -force +Get-ExecutionPolicy -List +echo $env:Path +echo "JAVA_HOME: $env:JAVA_HOME" +echo "PYTHONPATH: $env:PYTHONPATH" +echo "Cassandra version: $env:CASSANDRA_VERSION" +echo "Simulacron jar: $env:SIMULACRON_JAR" +echo $env:ci_type +python --version +python -c "import platform; print(platform.architecture())" + +$wc = New-Object 'System.Net.WebClient' + +if($env:ci_type -eq 'unit'){ + echo "Running Unit tests" + pytest -s -v --junit-xml=unit_results.xml .\tests\unit + + $env:EVENT_LOOP_MANAGER="gevent" + pytest -s -v --junit-xml=unit_results.xml .\tests\unit\io\test_geventreactor.py + $env:EVENT_LOOP_MANAGER="eventlet" + pytest -s -v --junit-xml=unit_results.xml .\tests\unit\io\test_eventletreactor.py + $env:EVENT_LOOP_MANAGER="asyncore" + + echo "uploading unit results" + $wc.UploadFile("https://ci.appveyor.com/api/testresults/junit/$($env:APPVEYOR_JOB_ID)", (Resolve-Path .\unit_results.xml)) + +} + +if($env:ci_type -eq 'standard'){ + + echo "Running CQLEngine integration tests" + pytest -s -v --junit-xml=cqlengine_results.xml .\tests\integration\cqlengine + $cqlengine_tests_result = $lastexitcode + $wc.UploadFile("https://ci.appveyor.com/api/testresults/junit/$($env:APPVEYOR_JOB_ID)", (Resolve-Path .\cqlengine_results.xml)) + echo "uploading CQLEngine test results" + + echo "Running standard integration tests" + pytest -s -v --junit-xml=standard_results.xml .\tests\integration\standard + $integration_tests_result = $lastexitcode + $wc.UploadFile("https://ci.appveyor.com/api/testresults/junit/$($env:APPVEYOR_JOB_ID)", (Resolve-Path .\standard_results.xml)) + echo "uploading standard integration test results" +} + + +$exit_result = $unit_tests_result + $cqlengine_tests_result + $integration_tests_result + $simulacron_tests_result +echo "Exit result: $exit_result" +exit $exit_result diff --git a/benchmarks/base.py b/benchmarks/base.py index 812db42aaa..290ba28788 100644 --- a/benchmarks/base.py +++ b/benchmarks/base.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -19,6 +21,7 @@ from threading import Thread import time from optparse import OptionParser +import uuid from greplin import scales @@ -29,7 +32,6 @@ import cassandra from cassandra.cluster import Cluster from cassandra.io.asyncorereactor import AsyncoreConnection -from cassandra.policies import HostDistance log = logging.getLogger() handler = logging.StreamHandler() @@ -38,6 +40,16 @@ logging.getLogger('cassandra').setLevel(logging.WARN) +_log_levels = { + 'CRITICAL': logging.CRITICAL, + 'ERROR': logging.ERROR, + 'WARN': logging.WARNING, + 'WARNING': logging.WARNING, + 'INFO': logging.INFO, + 'DEBUG': logging.DEBUG, + 'NOTSET': logging.NOTSET, +} + have_libev = False supported_reactors = [AsyncoreConnection] try: @@ -47,6 +59,14 @@ except ImportError as exc: pass +have_asyncio = False +try: + from cassandra.io.asyncioreactor import AsyncioConnection + have_asyncio = True + supported_reactors.append(AsyncioConnection) +except (ImportError, SyntaxError): + pass + have_twisted = False try: from cassandra.io.twistedreactor import TwistedConnection @@ -59,49 +79,65 @@ KEYSPACE = "testkeyspace" + str(int(time.time())) TABLE = "testtable" +COLUMN_VALUES = { + 'int': 42, + 'text': "'42'", + 'float': 42.0, + 'uuid': uuid.uuid4(), + 'timestamp': "'2016-02-03 04:05+0000'" +} + -def setup(hosts): +def setup(options): log.info("Using 'cassandra' package from %s", cassandra.__path__) - cluster = Cluster(hosts, protocol_version=1) - cluster.set_core_connections_per_host(HostDistance.LOCAL, 1) + cluster = Cluster(options.hosts, schema_metadata_enabled=False, token_metadata_enabled=False) try: session = cluster.connect() log.debug("Creating keyspace...") - session.execute(""" - CREATE KEYSPACE %s - WITH replication = { 'class': 'SimpleStrategy', 'replication_factor': '2' } - """ % KEYSPACE) + try: + session.execute(""" + CREATE KEYSPACE %s + WITH replication = { 'class': 'SimpleStrategy', 'replication_factor': '2' } + """ % options.keyspace) + + log.debug("Setting keyspace...") + except cassandra.AlreadyExists: + log.debug("Keyspace already exists") - log.debug("Setting keyspace...") - session.set_keyspace(KEYSPACE) + session.set_keyspace(options.keyspace) log.debug("Creating table...") - session.execute(""" - CREATE TABLE %s ( + create_table_query = """ + CREATE TABLE {0} ( thekey text, - col1 text, - col2 text, - PRIMARY KEY (thekey, col1) - ) - """ % TABLE) + """ + for i in range(options.num_columns): + create_table_query += "col{0} {1},\n".format(i, options.column_type) + create_table_query += "PRIMARY KEY (thekey))" + + try: + session.execute(create_table_query.format(TABLE)) + except cassandra.AlreadyExists: + log.debug("Table already exists.") + finally: cluster.shutdown() -def teardown(hosts): - cluster = Cluster(hosts, protocol_version=1) - cluster.set_core_connections_per_host(HostDistance.LOCAL, 1) +def teardown(options): + cluster = Cluster(options.hosts, schema_metadata_enabled=False, 
token_metadata_enabled=False) session = cluster.connect() - session.execute("DROP KEYSPACE " + KEYSPACE) + if not options.keep_data: + session.execute("DROP KEYSPACE " + options.keyspace) cluster.shutdown() def benchmark(thread_class): options, args = parse_options() for conn_class in options.supported_reactors: - setup(options.hosts) + setup(options) log.info("==== %s ====" % (conn_class.__name__,)) kwargs = {'metrics_enabled': options.enable_metrics, @@ -109,20 +145,30 @@ def benchmark(thread_class): if options.protocol_version: kwargs['protocol_version'] = options.protocol_version cluster = Cluster(options.hosts, **kwargs) - session = cluster.connect(KEYSPACE) + session = cluster.connect(options.keyspace) log.debug("Sleeping for two seconds...") time.sleep(2.0) - query = session.prepare(""" - INSERT INTO {table} (thekey, col1, col2) VALUES (?, ?, ?) - """.format(table=TABLE)) - values = ('key', 'a', 'b') + # Generate the query + if options.read: + query = "SELECT * FROM {0} WHERE thekey = '{{key}}'".format(TABLE) + else: + query = "INSERT INTO {0} (thekey".format(TABLE) + for i in range(options.num_columns): + query += ", col{0}".format(i) + + query += ") VALUES ('{key}'" + for i in range(options.num_columns): + query += ", {0}".format(COLUMN_VALUES[options.column_type]) + query += ")" + + values = None # we don't use that anymore. Keeping it in case we go back to prepared statements. per_thread = options.num_ops // options.threads threads = [] - log.debug("Beginning inserts...") + log.debug("Beginning {0}...".format('reads' if options.read else 'inserts')) start = time.time() try: for i in range(options.threads): @@ -142,7 +188,7 @@ def benchmark(thread_class): end = time.time() finally: cluster.shutdown() - teardown(options.hosts) + teardown(options) total = end - start log.info("Total time: %0.2fs" % total) @@ -180,6 +226,8 @@ def parse_options(): help='number of operations [default: %default]') parser.add_option('--asyncore-only', action='store_true', dest='asyncore_only', help='only benchmark with asyncore connections') + parser.add_option('--asyncio-only', action='store_true', dest='asyncio_only', + help='only benchmark with asyncio connections') parser.add_option('--libev-only', action='store_true', dest='libev_only', help='only benchmark with libev connections') parser.add_option('--twisted-only', action='store_true', dest='twisted_only', @@ -190,17 +238,34 @@ def parse_options(): help='logging level: debug, info, warning, or error') parser.add_option('-p', '--profile', action='store_true', dest='profile', help='Profile the run') - parser.add_option('--protocol-version', type='int', dest='protocol_version', + parser.add_option('--protocol-version', type='int', dest='protocol_version', default=4, help='Native protocol version to use') + parser.add_option('-c', '--num-columns', type='int', dest='num_columns', default=2, + help='Specify the number of columns for the schema') + parser.add_option('-k', '--keyspace', type='str', dest='keyspace', default=KEYSPACE, + help='Specify the keyspace name for the schema') + parser.add_option('--keep-data', action='store_true', dest='keep_data', default=False, + help='Keep the data after the benchmark') + parser.add_option('--column-type', type='str', dest='column_type', default='text', + help='Specify the column type for the schema (supported: int, text, float, uuid, timestamp)') + parser.add_option('--read', action='store_true', dest='read', default=False, + help='Read mode') + options, args = parser.parse_args() options.hosts = 
options.hosts.split(',') - log.setLevel(options.log_level.upper()) + level = options.log_level.upper() + try: + log.setLevel(_log_levels[level]) + except KeyError: + log.warning("Unknown log level specified: %s; specify one of %s", options.log_level, _log_levels.keys()) if options.asyncore_only: options.supported_reactors = [AsyncoreConnection] + elif options.asyncio_only: + options.supported_reactors = [AsyncioConnection] elif options.libev_only: if not have_libev: log.error("libev is not available") @@ -235,6 +300,9 @@ def start_profile(self): if self.profiler: self.profiler.enable() + def run_query(self, key, **kwargs): + return self.session.execute_async(self.query.format(key=key), **kwargs) + def finish_profile(self): if self.profiler: self.profiler.disable() diff --git a/benchmarks/callback_full_pipeline.py b/benchmarks/callback_full_pipeline.py index 3736991b4e..5eafa5df8b 100644 --- a/benchmarks/callback_full_pipeline.py +++ b/benchmarks/callback_full_pipeline.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -18,7 +20,6 @@ from threading import Event from base import benchmark, BenchmarkThread -from six.moves import range log = logging.getLogger(__name__) @@ -41,8 +42,10 @@ def insert_next(self, previous_result=sentinel): if next(self.num_finished) >= self.num_queries: self.event.set() - if next(self.num_started) <= self.num_queries: - future = self.session.execute_async(self.query, self.values, timeout=None) + i = next(self.num_started) + if i <= self.num_queries: + key = "{0}-{1}".format(self.thread_num, i) + future = self.run_query(key, timeout=None) future.add_callbacks(self.insert_next, self.insert_next) def run(self): diff --git a/benchmarks/future_batches.py b/benchmarks/future_batches.py index 91c250bcf9..112cc24981 100644 --- a/benchmarks/future_batches.py +++ b/benchmarks/future_batches.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -14,7 +16,7 @@ import logging from base import benchmark, BenchmarkThread -from six.moves import queue +import queue log = logging.getLogger(__name__) @@ -35,7 +37,8 @@ def run(self): except queue.Empty: break - future = self.session.execute_async(self.query, self.values) + key = "{0}-{1}".format(self.thread_num, i) + future = self.run_query(key) futures.put_nowait(future) while True: diff --git a/benchmarks/future_full_pipeline.py b/benchmarks/future_full_pipeline.py index 40682e0418..ca95b742d2 100644 --- a/benchmarks/future_full_pipeline.py +++ b/benchmarks/future_full_pipeline.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -14,7 +16,7 @@ import logging from base import benchmark, BenchmarkThread -from six.moves import queue +import queue log = logging.getLogger(__name__) @@ -31,7 +33,8 @@ def run(self): old_future = futures.get_nowait() old_future.result() - future = self.session.execute_async(self.query, self.values) + key = "{}-{}".format(self.thread_num, i) + future = self.run_query(key) futures.put_nowait(future) while True: diff --git a/benchmarks/future_full_throttle.py b/benchmarks/future_full_throttle.py index 27d87442bb..f85eb99b0d 100644 --- a/benchmarks/future_full_throttle.py +++ b/benchmarks/future_full_throttle.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -25,8 +27,9 @@ def run(self): self.start_profile() - for _ in range(self.num_queries): - future = self.session.execute_async(self.query, self.values) + for i in range(self.num_queries): + key = "{0}-{1}".format(self.thread_num, i) + future = self.run_query(key) futures.append(future) for future in futures: diff --git a/benchmarks/sync.py b/benchmarks/sync.py index 531e41fbe8..090a265579 100644 --- a/benchmarks/sync.py +++ b/benchmarks/sync.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -13,7 +15,6 @@ # limitations under the License. from base import benchmark, BenchmarkThread -from six.moves import range class Runner(BenchmarkThread): diff --git a/build.yaml b/build.yaml deleted file mode 100644 index 1d3c93915e..0000000000 --- a/build.yaml +++ /dev/null @@ -1,49 +0,0 @@ -python: - - 2.7 - - 3.4 -os: - - ubuntu/trusty64 -cassandra: - - 2.0 - - 2.1 - - 2.2 - - 3.0 - - 3.4 -env: - EVENT_LOOP_MANAGER: - - libev - CYTHON: - - CYTHON - - NO_CYTHON -build: - - script: | - export JAVA_HOME=$CCM_JAVA_HOME - export PATH=$JAVA_HOME/bin:$PATH - - # Install dependencies - if [[ $EVENT_LOOP_MANAGER == 'libev' ]]; then - sudo apt-get install -y libev4 libev-dev - fi - pip install -r test-requirements.txt - pip install nose-ignore-docstring - - if [[ $CYTHON == 'CYTHON' ]]; then - pip install cython - pip install numpy - # Install the driver & compile C extensions - python setup.py build_ext --inplace - else - # Install the driver & compile C extensions with no cython - python setup.py build_ext --inplace --no-cython - fi - - echo "==========RUNNING CQLENGINE TESTS==========" - CASSANDRA_VERSION=$CCM_CASSANDRA_VERSION nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=cqle_results.xml tests/integration/cqlengine/ || true - - echo "==========RUNNING INTEGRATION TESTS==========" - CASSANDRA_VERSION=$CCM_CASSANDRA_VERSION nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=standard_results.xml tests/integration/standard/ || true - - echo "==========RUNNING LONG INTEGRATION TESTS==========" - CASSANDRA_VERSION=$CCM_CASSANDRA_VERSION nosetests -s -v --logging-format="[%(levelname)s] %(asctime)s %(thread)d: %(message)s" --with-ignore-docstrings --with-xunit --xunit-file=long_results.xml tests/integration/long/ || true - - xunit: - - 
"*_results.xml" diff --git a/cassandra/__init__.py b/cassandra/__init__.py index 33fe14c384..6d0744aa6e 100644 --- a/cassandra/__init__.py +++ b/cassandra/__init__.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -22,13 +24,13 @@ def emit(self, record): logging.getLogger('cassandra').addHandler(NullHandler()) -__version_info__ = (3, 2, 0, 'a1', 'post0') +__version_info__ = (3, 29, 3) __version__ = '.'.join(map(str, __version_info__)) class ConsistencyLevel(object): """ - Spcifies how many replicas must respond for an operation to be considered + Specifies how many replicas must respond for an operation to be considered a success. By default, ``ONE`` is used for all operations. """ @@ -55,7 +57,7 @@ class ConsistencyLevel(object): QUORUM = 4 """ - ``ceil(RF/2)`` replicas must respond to consider the operation a success + ``ceil(RF/2) + 1`` replicas must respond to consider the operation a success """ ALL = 5 @@ -92,6 +94,11 @@ class ConsistencyLevel(object): one response. """ + @staticmethod + def is_serial(cl): + return cl == ConsistencyLevel.SERIAL or cl == ConsistencyLevel.LOCAL_SERIAL + + ConsistencyLevel.value_to_name = { ConsistencyLevel.ANY: 'ANY', ConsistencyLevel.ONE: 'ONE', @@ -125,6 +132,190 @@ def consistency_value_to_name(value): return ConsistencyLevel.value_to_name[value] if value is not None else "Not Set" +class ProtocolVersion(object): + """ + Defines native protocol versions supported by this driver. + """ + V1 = 1 + """ + v1, supported in Cassandra 1.2-->2.2 + """ + + V2 = 2 + """ + v2, supported in Cassandra 2.0-->2.2; + added support for lightweight transactions, batch operations, and automatic query paging. + """ + + V3 = 3 + """ + v3, supported in Cassandra 2.1-->3.x+; + added support for protocol-level client-side timestamps (see :attr:`.Session.use_client_timestamp`), + serial consistency levels for :class:`~.BatchStatement`, and an improved connection pool. + """ + + V4 = 4 + """ + v4, supported in Cassandra 2.2-->3.x+; + added a number of new types, server warnings, new failure messages, and custom payloads. Details in the + `project docs `_ + """ + + V5 = 5 + """ + v5, in beta from 3.x+. Finalised in 4.0-beta5 + """ + + V6 = 6 + """ + v6, in beta from 4.0-beta5 + """ + + DSE_V1 = 0x41 + """ + DSE private protocol v1, supported in DSE 5.1+ + """ + + DSE_V2 = 0x42 + """ + DSE private protocol v2, supported in DSE 6.0+ + """ + + SUPPORTED_VERSIONS = (DSE_V2, DSE_V1, V6, V5, V4, V3, V2, V1) + """ + A tuple of all supported protocol versions + """ + + BETA_VERSIONS = (V6,) + """ + A tuple of all beta protocol versions + """ + + MIN_SUPPORTED = min(SUPPORTED_VERSIONS) + """ + Minimum protocol version supported by this driver. 
+ """ + + MAX_SUPPORTED = max(SUPPORTED_VERSIONS) + """ + Maximum protocol version supported by this driver. + """ + + @classmethod + def get_lower_supported(cls, previous_version): + """ + Return the lower supported protocol version. Beta versions are omitted. + """ + try: + version = next(v for v in sorted(ProtocolVersion.SUPPORTED_VERSIONS, reverse=True) if + v not in ProtocolVersion.BETA_VERSIONS and v < previous_version) + except StopIteration: + version = 0 + + return version + + @classmethod + def uses_int_query_flags(cls, version): + return version >= cls.V5 + + @classmethod + def uses_prepare_flags(cls, version): + return version >= cls.V5 and version != cls.DSE_V1 + + @classmethod + def uses_prepared_metadata(cls, version): + return version >= cls.V5 and version != cls.DSE_V1 + + @classmethod + def uses_error_code_map(cls, version): + return version >= cls.V5 + + @classmethod + def uses_keyspace_flag(cls, version): + return version >= cls.V5 and version != cls.DSE_V1 + + @classmethod + def has_continuous_paging_support(cls, version): + return version >= cls.DSE_V1 + + @classmethod + def has_continuous_paging_next_pages(cls, version): + return version >= cls.DSE_V2 + + @classmethod + def has_checksumming_support(cls, version): + return cls.V5 <= version < cls.DSE_V1 + + +class WriteType(object): + """ + For usage with :class:`.RetryPolicy`, this describes a type + of write operation. + """ + + SIMPLE = 0 + """ + A write to a single partition key. Such writes are guaranteed to be atomic + and isolated. + """ + + BATCH = 1 + """ + A write to multiple partition keys that used the distributed batch log to + ensure atomicity. + """ + + UNLOGGED_BATCH = 2 + """ + A write to multiple partition keys that did not use the distributed batch + log. Atomicity for such writes is not guaranteed. + """ + + COUNTER = 3 + """ + A counter write (for one or multiple partition keys). Such writes should + not be replayed in order to avoid over counting. + """ + + BATCH_LOG = 4 + """ + The initial write to the distributed batch log that Cassandra performs + internally before a BATCH write. + """ + + CAS = 5 + """ + A lightweight-transaction write, such as "DELETE ... IF EXISTS". + """ + + VIEW = 6 + """ + This WriteType is only seen in results for requests that were unable to + complete MV operations. + """ + + CDC = 7 + """ + This WriteType is only seen in results for requests that were unable to + complete CDC operations. + """ + + +WriteType.name_to_value = { + 'SIMPLE': WriteType.SIMPLE, + 'BATCH': WriteType.BATCH, + 'UNLOGGED_BATCH': WriteType.UNLOGGED_BATCH, + 'COUNTER': WriteType.COUNTER, + 'BATCH_LOG': WriteType.BATCH_LOG, + 'CAS': WriteType.CAS, + 'VIEW': WriteType.VIEW, + 'CDC': WriteType.CDC +} + + +WriteType.value_to_name = {v: k for k, v in WriteType.name_to_value.items()} + + class SchemaChangeType(object): DROPPED = 'DROPPED' CREATED = 'CREATED' @@ -194,7 +385,21 @@ class UserAggregateDescriptor(SignatureDescriptor): """ -class Unavailable(Exception): +class DriverException(Exception): + """ + Base for all exceptions explicitly raised by the driver. + """ + pass + + +class RequestExecutionException(DriverException): + """ + Base for request execution exceptions returned from the server. 
+ """ + pass + + +class Unavailable(RequestExecutionException): """ There were not enough live replicas to satisfy the requested consistency level, so the coordinator node immediately failed the request without @@ -220,7 +425,7 @@ def __init__(self, summary_message, consistency=None, required_replicas=None, al 'alive_replicas': alive_replicas})) -class Timeout(Exception): +class Timeout(RequestExecutionException): """ Replicas failed to respond to the coordinator node before timing out. """ @@ -237,14 +442,21 @@ class Timeout(Exception): the operation """ - def __init__(self, summary_message, consistency=None, required_responses=None, received_responses=None): + def __init__(self, summary_message, consistency=None, required_responses=None, + received_responses=None, **kwargs): self.consistency = consistency self.required_responses = required_responses self.received_responses = received_responses - Exception.__init__(self, summary_message + ' info=' + - repr({'consistency': consistency_value_to_name(consistency), - 'required_responses': required_responses, - 'received_responses': received_responses})) + + if "write_type" in kwargs: + kwargs["write_type"] = WriteType.value_to_name[kwargs["write_type"]] + + info = {'consistency': consistency_value_to_name(consistency), + 'required_responses': required_responses, + 'received_responses': received_responses} + info.update(kwargs) + + Exception.__init__(self, summary_message + ' info=' + repr(info)) class ReadTimeout(Timeout): @@ -285,11 +497,20 @@ class WriteTimeout(Timeout): """ def __init__(self, message, write_type=None, **kwargs): + kwargs["write_type"] = write_type Timeout.__init__(self, message, **kwargs) self.write_type = write_type -class CoordinationFailure(Exception): +class CDCWriteFailure(RequestExecutionException): + """ + Hit limit on data in CDC folder, writes are rejected + """ + def __init__(self, message): + Exception.__init__(self, message) + + +class CoordinationFailure(RequestExecutionException): """ Replicas sent a failure to the coordinator. """ @@ -311,16 +532,34 @@ class CoordinationFailure(Exception): The number of replicas that sent a failure message """ - def __init__(self, summary_message, consistency=None, required_responses=None, received_responses=None, failures=None): + error_code_map = None + """ + A map of inet addresses to error codes representing replicas that sent + a failure message. Only set when `protocol_version` is 5 or higher. 
+ """ + + def __init__(self, summary_message, consistency=None, required_responses=None, + received_responses=None, failures=None, error_code_map=None): self.consistency = consistency self.required_responses = required_responses self.received_responses = received_responses self.failures = failures - Exception.__init__(self, summary_message + ' info=' + - repr({'consistency': consistency_value_to_name(consistency), - 'required_responses': required_responses, - 'received_responses': received_responses, - 'failures': failures})) + self.error_code_map = error_code_map + + info_dict = { + 'consistency': consistency_value_to_name(consistency), + 'required_responses': required_responses, + 'received_responses': received_responses, + 'failures': failures + } + + if error_code_map is not None: + # make error codes look like "0x002a" + formatted_map = dict((addr, '0x%04x' % err_code) + for (addr, err_code) in error_code_map.items()) + info_dict['error_code_map'] = formatted_map + + Exception.__init__(self, summary_message + ' info=' + repr(info_dict)) class ReadFailure(CoordinationFailure): @@ -359,7 +598,7 @@ def __init__(self, message, write_type=None, **kwargs): self.write_type = write_type -class FunctionFailure(Exception): +class FunctionFailure(RequestExecutionException): """ User Defined Function failed during execution """ @@ -386,7 +625,21 @@ def __init__(self, summary_message, keyspace, function, arg_types): Exception.__init__(self, summary_message) -class AlreadyExists(Exception): +class RequestValidationException(DriverException): + """ + Server request validation failed + """ + pass + + +class ConfigurationException(RequestValidationException): + """ + Server indicated request errro due to current configuration + """ + pass + + +class AlreadyExists(ConfigurationException): """ An attempt was made to create a keyspace or table that already exists. """ @@ -414,7 +667,7 @@ def __init__(self, keyspace=None, table=None): self.table = table -class InvalidRequest(Exception): +class InvalidRequest(RequestValidationException): """ A query was made that was invalid for some reason, such as trying to set the keyspace for a connection to a nonexistent keyspace. @@ -422,21 +675,21 @@ class InvalidRequest(Exception): pass -class Unauthorized(Exception): +class Unauthorized(RequestValidationException): """ - The current user is not authorized to perfom the requested operation. + The current user is not authorized to perform the requested operation. """ pass -class AuthenticationFailed(Exception): +class AuthenticationFailed(DriverException): """ Failed to authenticate. """ pass -class OperationTimedOut(Exception): +class OperationTimedOut(DriverException): """ The operation took longer than the specified (client-side) timeout to complete. This is not an error generated by Cassandra, only @@ -460,10 +713,36 @@ def __init__(self, errors=None, last_host=None): Exception.__init__(self, message) -class UnsupportedOperation(Exception): +class UnsupportedOperation(DriverException): """ An attempt was made to use a feature that is not supported by the selected protocol version. See :attr:`Cluster.protocol_version` for more details. """ pass + + +class UnresolvableContactPoints(DriverException): + """ + The driver was unable to resolve any provided hostnames. 
+ + Note that this is *not* raised when a :class:`.Cluster` is created with no + contact points, only when lookup fails for all hosts + """ + pass + +class DependencyException(Exception): + """ + Specific exception class for handling issues with driver dependencies + """ + + excs = [] + """ + A sequence of child exceptions + """ + + def __init__(self, msg, excs=[]): + complete_msg = msg + if excs: + complete_msg += ("\nThe following exceptions were observed: \n - " + '\n - '.join(str(e) for e in excs)) + Exception.__init__(self, complete_msg) diff --git a/cassandra/auth.py b/cassandra/auth.py index b562728a24..86759afe4d 100644 --- a/cassandra/auth.py +++ b/cassandra/auth.py @@ -1,20 +1,45 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and +# limitations under the License. + +import socket +import logging + +try: + import kerberos + _have_kerberos = True +except ImportError: + _have_kerberos = False + +try: + from puresasl.client import SASLClient + _have_puresasl = True +except ImportError: + _have_puresasl = False + try: from puresasl.client import SASLClient except ImportError: SASLClient = None +log = logging.getLogger(__name__) + +# Custom payload keys related to DSE Unified Auth +_proxy_execute_key = 'ProxyExecute' + + class AuthProvider(object): """ An abstract class that defines the interface that will be used for @@ -54,7 +79,7 @@ class Authenticator(object): 3) When the server indicates that authentication is successful, :meth:`~.on_authentication_success` will be called a token string that - that the server may optionally have sent. + the server may optionally have sent. The exact nature of the negotiation between the client and server is specific to the authentication mechanism configured server-side. @@ -67,7 +92,7 @@ class Authenticator(object): def initial_response(self): """ - Returns an message to send to the server to initiate the SASL handshake. + Returns a message to send to the server to initiate the SASL handshake. :const:`None` may be returned to send an empty message. """ return None @@ -113,22 +138,31 @@ def new_authenticator(self, host): return PlainTextAuthenticator(self.username, self.password) -class PlainTextAuthenticator(Authenticator): +class TransitionalModePlainTextAuthProvider(object): """ - An :class:`~.Authenticator` that works with Cassandra's PasswordAuthenticator. + An :class:`~.AuthProvider` that works with DSE TransitionalModePlainTextAuthenticator. - .. 
versionadded:: 2.0.0 - """ + Example usage:: - def __init__(self, username, password): - self.username = username - self.password = password + from cassandra.cluster import Cluster + from cassandra.auth import TransitionalModePlainTextAuthProvider - def initial_response(self): - return "\x00%s\x00%s" % (self.username, self.password) + auth_provider = TransitionalModePlainTextAuthProvider() + cluster = Cluster(auth_provider=auth_provider) - def evaluate_challenge(self, challenge): - return None + .. warning:: TransitionalModePlainTextAuthProvider will be removed in cassandra-driver + 4.0. The transitional mode will be handled internally without the need + of any auth provider. + """ + + def __init__(self): + # TODO remove next major + log.warning("TransitionalModePlainTextAuthProvider will be removed in cassandra-driver " + "4.0. The transitional mode will be handled internally without the need " + "of any auth provider.") + + def new_authenticator(self, host): + return TransitionalModePlainTextAuthenticator() class SaslAuthProvider(AuthProvider): @@ -180,3 +214,96 @@ def initial_response(self): def evaluate_challenge(self, challenge): return self.sasl.process(challenge) + +# TODO remove me next major +DSEPlainTextAuthProvider = PlainTextAuthProvider + + +class DSEGSSAPIAuthProvider(AuthProvider): + """ + Auth provider for GSS API authentication. Works with legacy `KerberosAuthenticator` + or `DseAuthenticator` if `kerberos` scheme is enabled. + """ + def __init__(self, service='dse', qops=('auth',), resolve_host_name=True, **properties): + """ + :param service: name of the service + :param qops: iterable of "Quality of Protection" allowed; see ``puresasl.QOP`` + :param resolve_host_name: boolean flag indicating whether the authenticator should reverse-lookup an FQDN when + creating a new authenticator. Default is ``True``, which will resolve, or return the numeric address if there is no PTR + record. Setting ``False`` creates the authenticator with the numeric address known by Cassandra + :param properties: additional keyword properties to pass for the ``puresasl.mechanisms.GSSAPIMechanism`` class. 
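A usage sketch for the GSS API provider described here (illustrative only; it requires the optional ``puresasl`` and ``kerberos`` dependencies, and the contact point and principal below are placeholders)::

    from cassandra.cluster import Cluster
    from cassandra.auth import DSEGSSAPIAuthProvider

    # 'principal' is forwarded to pure-sasl's GSSAPI mechanism; the service
    # name and QOP shown here are the documented defaults.
    auth_provider = DSEGSSAPIAuthProvider(service='dse', qops=('auth',),
                                          principal='user@EXAMPLE.COM')
    cluster = Cluster(['10.0.0.1'], auth_provider=auth_provider)
    session = cluster.connect()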
+ Presently, 'principal' (user) is the only one referenced in the ``pure-sasl`` implementation + """ + if not _have_puresasl: + raise ImportError('The puresasl library has not been installed') + if not _have_kerberos: + raise ImportError('The kerberos library has not been installed') + self.service = service + self.qops = qops + self.resolve_host_name = resolve_host_name + self.properties = properties + + def new_authenticator(self, host): + if self.resolve_host_name: + host = socket.getnameinfo((host, 0), 0)[0] + return GSSAPIAuthenticator(host, self.service, self.qops, self.properties) + + +class BaseDSEAuthenticator(Authenticator): + def get_mechanism(self): + raise NotImplementedError("get_mechanism not implemented") + + def get_initial_challenge(self): + raise NotImplementedError("get_initial_challenge not implemented") + + def initial_response(self): + if self.server_authenticator_class == "com.datastax.bdp.cassandra.auth.DseAuthenticator": + return self.get_mechanism() + else: + return self.evaluate_challenge(self.get_initial_challenge()) + + +class PlainTextAuthenticator(BaseDSEAuthenticator): + + def __init__(self, username, password): + self.username = username + self.password = password + + def get_mechanism(self): + return b"PLAIN" + + def get_initial_challenge(self): + return b"PLAIN-START" + + def evaluate_challenge(self, challenge): + if challenge == b'PLAIN-START': + data = "\x00%s\x00%s" % (self.username, self.password) + return data.encode() + raise Exception('Did not receive a valid challenge response from server') + + +class TransitionalModePlainTextAuthenticator(PlainTextAuthenticator): + """ + Authenticator that accounts for DSE authentication is configured with transitional mode. + """ + + def __init__(self): + super(TransitionalModePlainTextAuthenticator, self).__init__('', '') + + +class GSSAPIAuthenticator(BaseDSEAuthenticator): + def __init__(self, host, service, qops, properties): + properties = properties or {} + self.sasl = SASLClient(host, service, 'GSSAPI', qops=qops, **properties) + + def get_mechanism(self): + return b"GSSAPI" + + def get_initial_challenge(self): + return b"GSSAPI-START" + + def evaluate_challenge(self, challenge): + if challenge == b'GSSAPI-START': + return self.sasl.process() + else: + return self.sasl.process(challenge) diff --git a/cassandra/buffer.pxd b/cassandra/buffer.pxd index f9976f09aa..3383fcd272 100644 --- a/cassandra/buffer.pxd +++ b/cassandra/buffer.pxd @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/cassandra/bytesio.pxd b/cassandra/bytesio.pxd index a0bb083fac..24320f0ae1 100644 --- a/cassandra/bytesio.pxd +++ b/cassandra/bytesio.pxd @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/cassandra/bytesio.pyx b/cassandra/bytesio.pyx index 3334697023..d9781035ef 100644 --- a/cassandra/bytesio.pyx +++ b/cassandra/bytesio.pyx @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/cassandra/cluster.py b/cassandra/cluster.py index d8c1026e90..43066f73f0 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -19,35 +21,37 @@ from __future__ import absolute_import import atexit -from collections import defaultdict, Mapping -from concurrent.futures import ThreadPoolExecutor +from binascii import hexlify +from collections import defaultdict +from collections.abc import Mapping +from concurrent.futures import ThreadPoolExecutor, FIRST_COMPLETED, wait as wait_futures +from copy import copy +from functools import partial, reduce, wraps +from itertools import groupby, count, chain +import json import logging +from warnings import warn from random import random +import re +import queue import socket import sys import time from threading import Lock, RLock, Thread, Event -import warnings - -import six -from six.moves import range -from six.moves import queue as Queue +import uuid import weakref from weakref import WeakValueDictionary -try: - from weakref import WeakSet -except ImportError: - from cassandra.util import WeakSet # NOQA - -from functools import partial, wraps -from itertools import groupby, count from cassandra import (ConsistencyLevel, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, - SchemaTargetType) + SchemaTargetType, DriverException, ProtocolVersion, + UnresolvableContactPoints, DependencyException) +from cassandra.auth import _proxy_execute_key, PlainTextAuthProvider from cassandra.connection import (ConnectionException, ConnectionShutdown, - ConnectionHeartbeat, ProtocolVersionUnsupported) + ConnectionHeartbeat, ProtocolVersionUnsupported, + EndPoint, DefaultEndPoint, DefaultEndPointFactory, + ContinuousPagingState, SniEndPointFactory, ConnectionBusy) from cassandra.cqltypes import UserType from cassandra.encoder import Encoder from cassandra.protocol import (QueryMessage, ResultMessage, @@ -58,28 +62,56 @@ PrepareMessage, ExecuteMessage, PreparedQueryNotFound, IsBootstrappingErrorMessage, + TruncateError, ServerError, BatchMessage, RESULT_KIND_PREPARED, RESULT_KIND_SET_KEYSPACE, RESULT_KIND_ROWS, - RESULT_KIND_SCHEMA_CHANGE, MIN_SUPPORTED_VERSION, - ProtocolHandler) -from cassandra.metadata import Metadata, protect_name, murmur3 + RESULT_KIND_SCHEMA_CHANGE, ProtocolHandler, + RESULT_KIND_VOID, ProtocolException) +from cassandra.metadata import Metadata, protect_name, murmur3, _NodeInfo from cassandra.policies import (TokenAwarePolicy, DCAwareRoundRobinPolicy, SimpleConvictionPolicy, ExponentialReconnectionPolicy, HostDistance, - RetryPolicy) + RetryPolicy, IdentityTranslator, NoSpeculativeExecutionPlan, + NoSpeculativeExecutionPolicy, DefaultLoadBalancingPolicy, + NeverRetryPolicy) from cassandra.pool import (Host, _ReconnectionHandler, _HostReconnectionHandler, HostConnectionPool, HostConnection, NoConnectionsAvailable) from cassandra.query import (SimpleStatement, PreparedStatement, BoundStatement, - BatchStatement, bind_params, QueryTrace, - named_tuple_factory, dict_factory, tuple_factory, FETCH_SIZE_UNSET) + BatchStatement, bind_params, QueryTrace, TraceUnavailable, + named_tuple_factory, dict_factory, tuple_factory, FETCH_SIZE_UNSET, + HostTargetingStatement) +from cassandra.marshal import int64_pack +from cassandra.timestamps import MonotonicTimestampGenerator +from cassandra.util import _resolve_contact_points_to_string_map, Version + +from cassandra.datastax.insights.reporter 
import MonitorReporter +from cassandra.datastax.insights.util import version_supports_insights + +from cassandra.datastax.graph import (graph_object_row_factory, GraphOptions, GraphSON1Serializer, + GraphProtocol, GraphSON2Serializer, GraphStatement, SimpleGraphStatement, + graph_graphson2_row_factory, graph_graphson3_row_factory, + GraphSON3Serializer) +from cassandra.datastax.graph.query import _request_timeout_key, _GraphSONContextRowFactory +from cassandra.datastax import cloud as dscloud +try: + from cassandra.io.twistedreactor import TwistedConnection +except ImportError: + TwistedConnection = None -def _is_eventlet_monkey_patched(): - if 'eventlet.patcher' not in sys.modules: - return False - import eventlet.patcher - return eventlet.patcher.is_monkey_patched('socket') +try: + from cassandra.io.eventletreactor import EventletConnection +# PYTHON-1364 +# +# At the moment eventlet initialization is chucking AttributeErrors due to its dependence on pyOpenSSL +# and some changes in Python 3.12 which have some knock-on effects there. +except (ImportError, AttributeError): + EventletConnection = None +try: + from weakref import WeakSet +except ImportError: + from cassandra.util import WeakSet # NOQA def _is_gevent_monkey_patched(): if 'gevent.monkey' not in sys.modules: @@ -87,27 +119,68 @@ def _is_gevent_monkey_patched(): import gevent.socket return socket.socket is gevent.socket.socket -# default to gevent when we are monkey patched with gevent, eventlet when -# monkey patched with eventlet, otherwise if libev is available, use that as -# the default because it's fastest. Otherwise, use asyncore. -if _is_gevent_monkey_patched(): - from cassandra.io.geventreactor import GeventConnection as DefaultConnection -elif _is_eventlet_monkey_patched(): - from cassandra.io.eventletreactor import EventletConnection as DefaultConnection -else: +def _try_gevent_import(): + if _is_gevent_monkey_patched(): + from cassandra.io.geventreactor import GeventConnection + return (GeventConnection,None) + else: + return (None,None) + +def _is_eventlet_monkey_patched(): + if 'eventlet.patcher' not in sys.modules: + return False + try: + import eventlet.patcher + return eventlet.patcher.is_monkey_patched('socket') + # Another case related to PYTHON-1364 + except AttributeError: + return False + +def _try_eventlet_import(): + if _is_eventlet_monkey_patched(): + from cassandra.io.eventletreactor import EventletConnection + return (EventletConnection,None) + else: + return (None,None) + +def _try_libev_import(): try: - from cassandra.io.libevreactor import LibevConnection as DefaultConnection # NOQA - except ImportError: - from cassandra.io.asyncorereactor import AsyncoreConnection as DefaultConnection # NOQA + from cassandra.io.libevreactor import LibevConnection + return (LibevConnection,None) + except DependencyException as e: + return (None, e) + +def _try_asyncore_import(): + try: + from cassandra.io.asyncorereactor import AsyncoreConnection + return (AsyncoreConnection,None) + except DependencyException as e: + return (None, e) + +def _connection_reduce_fn(val,import_fn): + (rv, excs) = val + # If we've already found a workable Connection class return immediately + if rv: + return val + (import_result, exc) = import_fn() + if exc: + excs.append(exc) + return (rv or import_result, excs) + +log = logging.getLogger(__name__) + +conn_fns = (_try_gevent_import, _try_eventlet_import, _try_libev_import, _try_asyncore_import) +(conn_class, excs) = reduce(_connection_reduce_fn, conn_fns, (None,[])) +if not conn_class: 
+ raise DependencyException("Unable to load a default connection class", excs) +DefaultConnection = conn_class # Forces load of utf8 encoding module to avoid deadlock that occurs -# if code that is being imported tries to import the module in a seperate +# if code that is being imported tries to import the module in a separate # thread. # See http://bugs.python.org/issue10923 "".encode('utf8') -log = logging.getLogger(__name__) - DEFAULT_MIN_REQUESTS = 5 DEFAULT_MAX_REQUESTS = 100 @@ -118,6 +191,7 @@ def _is_gevent_monkey_patched(): DEFAULT_MIN_CONNECTIONS_PER_REMOTE_HOST = 1 DEFAULT_MAX_CONNECTIONS_PER_REMOTE_HOST = 2 +_GRAPH_PAGING_MIN_DSE_VERSION = Version('6.8.0') _NOT_SET = object() @@ -166,21 +240,359 @@ def new_f(self, *args, **kwargs): return new_f -def _shutdown_cluster(cluster): - if cluster and not cluster.is_shutdown: +_clusters_for_shutdown = set() + + +def _register_cluster_shutdown(cluster): + _clusters_for_shutdown.add(cluster) + + +def _discard_cluster_shutdown(cluster): + _clusters_for_shutdown.discard(cluster) + + +def _shutdown_clusters(): + clusters = _clusters_for_shutdown.copy() # copy because shutdown modifies the global set "discard" + for cluster in clusters: cluster.shutdown() -# murmur3 implementation required for TokenAware is only available for CPython -import platform -if platform.python_implementation() == 'CPython': - def default_lbp_factory(): - if murmur3 is not None: - return TokenAwarePolicy(DCAwareRoundRobinPolicy()) - return DCAwareRoundRobinPolicy() -else: - def default_lbp_factory(): - return DCAwareRoundRobinPolicy() +atexit.register(_shutdown_clusters) + + +def default_lbp_factory(): + if murmur3 is not None: + return TokenAwarePolicy(DCAwareRoundRobinPolicy()) + return DCAwareRoundRobinPolicy() + + +class ContinuousPagingOptions(object): + + class PagingUnit(object): + BYTES = 1 + ROWS = 2 + + page_unit = None + """ + Value of PagingUnit. Default is PagingUnit.ROWS. + + Units refer to the :attr:`~.Statement.fetch_size` or :attr:`~.Session.default_fetch_size`. + """ + + max_pages = None + """ + Max number of pages to send + """ + + max_pages_per_second = None + """ + Max rate at which to send pages + """ + + max_queue_size = None + """ + The maximum queue size for caching pages, only honored for protocol version DSE_V2 and higher, + by default it is 4 and it must be at least 2. + """ + + def __init__(self, page_unit=PagingUnit.ROWS, max_pages=0, max_pages_per_second=0, max_queue_size=4): + self.page_unit = page_unit + self.max_pages = max_pages + self.max_pages_per_second = max_pages_per_second + if max_queue_size < 2: + raise ValueError('ContinuousPagingOptions.max_queue_size must be 2 or greater') + self.max_queue_size = max_queue_size + + def page_unit_bytes(self): + return self.page_unit == ContinuousPagingOptions.PagingUnit.BYTES + + +def _addrinfo_or_none(contact_point, port): + """ + A helper function that wraps socket.getaddrinfo and returns None + when it fails to, e.g. resolve one of the hostnames. Used to address + PYTHON-895. 
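The reactor-selection fallback above can also be reproduced explicitly by an application that wants to pin a connection class. A sketch, assuming the asyncore reactor is available as a last resort::

    from cassandra import DependencyException
    from cassandra.cluster import Cluster

    # Prefer the libev reactor, falling back to asyncore when the C extension
    # (or libev itself) cannot be loaded -- the same order tried above.
    try:
        from cassandra.io.libevreactor import LibevConnection as ConnectionClass
    except (DependencyException, ImportError):
        from cassandra.io.asyncorereactor import AsyncoreConnection as ConnectionClass

    cluster = Cluster(connection_class=ConnectionClass)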
+ """ + try: + return socket.getaddrinfo(contact_point, port, + socket.AF_UNSPEC, socket.SOCK_STREAM) + except socket.gaierror: + log.debug('Could not resolve hostname "{}" ' + 'with port {}'.format(contact_point, port)) + return None + + +def _execution_profile_to_string(name): + default_profiles = { + EXEC_PROFILE_DEFAULT: 'EXEC_PROFILE_DEFAULT', + EXEC_PROFILE_GRAPH_DEFAULT: 'EXEC_PROFILE_GRAPH_DEFAULT', + EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT: 'EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT', + EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT: 'EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT', + } + + if name in default_profiles: + return default_profiles[name] + + return '"%s"' % (name,) + + +class ExecutionProfile(object): + load_balancing_policy = None + """ + An instance of :class:`.policies.LoadBalancingPolicy` or one of its subclasses. + + Used in determining host distance for establishing connections, and routing requests. + + Defaults to ``TokenAwarePolicy(DCAwareRoundRobinPolicy())`` if not specified + """ + + retry_policy = None + """ + An instance of :class:`.policies.RetryPolicy` instance used when :class:`.Statement` objects do not have a + :attr:`~.Statement.retry_policy` explicitly set. + + Defaults to :class:`.RetryPolicy` if not specified + """ + + consistency_level = ConsistencyLevel.LOCAL_ONE + """ + :class:`.ConsistencyLevel` used when not specified on a :class:`.Statement`. + """ + + serial_consistency_level = None + """ + Serial :class:`.ConsistencyLevel` used when not specified on a :class:`.Statement` (for LWT conditional statements). + """ + + request_timeout = 10.0 + """ + Request timeout used when not overridden in :meth:`.Session.execute` + """ + + row_factory = staticmethod(named_tuple_factory) + """ + A callable to format results, accepting ``(colnames, rows)`` where ``colnames`` is a list of column names, and + ``rows`` is a list of tuples, with each tuple representing a row of parsed values. + + Some example implementations: + + - :func:`cassandra.query.tuple_factory` - return a result row as a tuple + - :func:`cassandra.query.named_tuple_factory` - return a result row as a named tuple + - :func:`cassandra.query.dict_factory` - return a result row as a dict + - :func:`cassandra.query.ordered_dict_factory` - return a result row as an OrderedDict + """ + + speculative_execution_policy = None + """ + An instance of :class:`.policies.SpeculativeExecutionPolicy` + + Defaults to :class:`.NoSpeculativeExecutionPolicy` if not specified + """ + + continuous_paging_options = None + """ + *Note:* This feature is implemented to facilitate server integration testing. It is not intended for general use in the Python driver. + See :attr:`.Statement.fetch_size` or :attr:`Session.default_fetch_size` for configuring normal paging. + + When set, requests will use DSE's continuous paging, which streams multiple pages without + intermediate requests. + + This has the potential to materialize all results in memory at once if the consumer cannot keep up. Use options + to constrain page size and rate. + + This is only available for DSE clusters. 
+ """ + + # indicates if lbp was set explicitly or uses default values + _load_balancing_policy_explicit = False + _consistency_level_explicit = False + + def __init__(self, load_balancing_policy=_NOT_SET, retry_policy=None, + consistency_level=_NOT_SET, serial_consistency_level=None, + request_timeout=10.0, row_factory=named_tuple_factory, speculative_execution_policy=None, + continuous_paging_options=None): + + if load_balancing_policy is _NOT_SET: + self._load_balancing_policy_explicit = False + self.load_balancing_policy = default_lbp_factory() + else: + self._load_balancing_policy_explicit = True + self.load_balancing_policy = load_balancing_policy + + if consistency_level is _NOT_SET: + self._consistency_level_explicit = False + self.consistency_level = ConsistencyLevel.LOCAL_ONE + else: + self._consistency_level_explicit = True + self.consistency_level = consistency_level + + self.retry_policy = retry_policy or RetryPolicy() + + if (serial_consistency_level is not None and + not ConsistencyLevel.is_serial(serial_consistency_level)): + raise ValueError("serial_consistency_level must be either " + "ConsistencyLevel.SERIAL " + "or ConsistencyLevel.LOCAL_SERIAL.") + self.serial_consistency_level = serial_consistency_level + + self.request_timeout = request_timeout + self.row_factory = row_factory + self.speculative_execution_policy = speculative_execution_policy or NoSpeculativeExecutionPolicy() + self.continuous_paging_options = continuous_paging_options + + +class GraphExecutionProfile(ExecutionProfile): + graph_options = None + """ + :class:`.GraphOptions` to use with this execution + + Default options for graph queries, initialized as follows by default:: + + GraphOptions(graph_language=b'gremlin-groovy') + + See cassandra.graph.GraphOptions + """ + + def __init__(self, load_balancing_policy=_NOT_SET, retry_policy=None, + consistency_level=_NOT_SET, serial_consistency_level=None, + request_timeout=30.0, row_factory=None, + graph_options=None, continuous_paging_options=_NOT_SET): + """ + Default execution profile for graph execution. + + See :class:`.ExecutionProfile` for base attributes. Note that if not explicitly set, + the row_factory and graph_options.graph_protocol are resolved during the query execution. + These options will resolve to graph_graphson3_row_factory and GraphProtocol.GRAPHSON_3_0 + for the core graph engine (DSE 6.8+), otherwise graph_object_row_factory and GraphProtocol.GRAPHSON_1_0 + + In addition to default parameters shown in the signature, this profile also defaults ``retry_policy`` to + :class:`cassandra.policies.NeverRetryPolicy`. + """ + retry_policy = retry_policy or NeverRetryPolicy() + super(GraphExecutionProfile, self).__init__(load_balancing_policy, retry_policy, consistency_level, + serial_consistency_level, request_timeout, row_factory, + continuous_paging_options=continuous_paging_options) + self.graph_options = graph_options or GraphOptions(graph_source=b'g', + graph_language=b'gremlin-groovy') + + +class GraphAnalyticsExecutionProfile(GraphExecutionProfile): + + def __init__(self, load_balancing_policy=None, retry_policy=None, + consistency_level=_NOT_SET, serial_consistency_level=None, + request_timeout=3600. * 24. * 7., row_factory=None, + graph_options=None): + """ + Execution profile with timeout and load balancing appropriate for graph analytics queries. + + See also :class:`~.GraphExecutionPolicy`. 
+ + In addition to default parameters shown in the signature, this profile also defaults ``retry_policy`` to + :class:`cassandra.policies.NeverRetryPolicy`, and ``load_balancing_policy`` to one that targets the current Spark + master. + + Note: The graph_options.graph_source is set automatically to b'a' (analytics) + when using GraphAnalyticsExecutionProfile. This is mandatory to target analytics nodes. + """ + load_balancing_policy = load_balancing_policy or DefaultLoadBalancingPolicy(default_lbp_factory()) + graph_options = graph_options or GraphOptions(graph_language=b'gremlin-groovy') + super(GraphAnalyticsExecutionProfile, self).__init__(load_balancing_policy, retry_policy, consistency_level, + serial_consistency_level, request_timeout, row_factory, + graph_options) + # ensure the graph_source is analytics, since this is the purpose of the GraphAnalyticsExecutionProfile + self.graph_options.set_source_analytics() + + +class ProfileManager(object): + + def __init__(self): + self.profiles = dict() + + def _profiles_without_explicit_lbps(self): + names = (profile_name for + profile_name, profile in self.profiles.items() + if not profile._load_balancing_policy_explicit) + return tuple( + 'EXEC_PROFILE_DEFAULT' if n is EXEC_PROFILE_DEFAULT else n + for n in names + ) + + def distance(self, host): + distances = set(p.load_balancing_policy.distance(host) for p in self.profiles.values()) + return HostDistance.LOCAL if HostDistance.LOCAL in distances else \ + HostDistance.REMOTE if HostDistance.REMOTE in distances else \ + HostDistance.IGNORED + + def populate(self, cluster, hosts): + for p in self.profiles.values(): + p.load_balancing_policy.populate(cluster, hosts) + + def check_supported(self): + for p in self.profiles.values(): + p.load_balancing_policy.check_supported() + + def on_up(self, host): + for p in self.profiles.values(): + p.load_balancing_policy.on_up(host) + + def on_down(self, host): + for p in self.profiles.values(): + p.load_balancing_policy.on_down(host) + + def on_add(self, host): + for p in self.profiles.values(): + p.load_balancing_policy.on_add(host) + + def on_remove(self, host): + for p in self.profiles.values(): + p.load_balancing_policy.on_remove(host) + + @property + def default(self): + """ + internal-only; no checks are done because this entry is populated on cluster init + """ + return self.profiles[EXEC_PROFILE_DEFAULT] + + +EXEC_PROFILE_DEFAULT = object() +""" +Key for the ``Cluster`` default execution profile, used when no other profile is selected in +``Session.execute(execution_profile)``. + +Use this as the key in ``Cluster(execution_profiles)`` to override the default profile. +""" + +EXEC_PROFILE_GRAPH_DEFAULT = object() +""" +Key for the default graph execution profile, used when no other profile is selected in +``Session.execute_graph(execution_profile)``. + +Use this as the key in :doc:`Cluster(execution_profiles) ` +to override the default graph profile. +""" + +EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT = object() +""" +Key for the default graph system execution profile. This can be used for graph statements using the DSE graph +system API. + +Selected using ``Session.execute_graph(execution_profile=EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT)``. +""" + +EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT = object() +""" +Key for the default graph analytics execution profile. This can be used for graph statements intended to +use Spark/analytics as the traversal source. + +Selected using ``Session.execute_graph(execution_profile=EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT)``. 
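A sketch of selecting a graph profile by key (DSE Graph only; the graph name is a placeholder)::

    from cassandra.cluster import (Cluster, GraphExecutionProfile,
                                   EXEC_PROFILE_GRAPH_DEFAULT)
    from cassandra.datastax.graph import GraphOptions

    ep = GraphExecutionProfile(graph_options=GraphOptions(graph_name='demo_graph'))
    cluster = Cluster(execution_profiles={EXEC_PROFILE_GRAPH_DEFAULT: ep})
    session = cluster.connect()
    for vertex in session.execute_graph('g.V().limit(5)'):
        print(vertex)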
+""" + + +class _ConfigMode(object): + UNCOMMITTED = 0 + LEGACY = 1 + PROFILES = 2 class Cluster(object): @@ -198,11 +610,15 @@ class Cluster(object): >>> ... >>> cluster.shutdown() + ``Cluster`` and ``Session`` also provide context management functions + which implicitly handle shutdown when leaving scope. """ contact_points = ['127.0.0.1'] """ - The list of contact points to try connecting for cluster discovery. + The list of contact points to try connecting for cluster discovery. A + contact point can be a string (ip or hostname), a tuple (ip/hostname, port) or a + :class:`.connection.EndPoint` instance. Defaults to loopback interface. @@ -210,7 +626,14 @@ class Cluster(object): local_dc set (as is the default), the DC is chosen from an arbitrary host in contact_points. In this case, contact_points should contain only nodes from a single, local DC. + + Note: In the next major version, if you specify contact points, you will + also be required to also explicitly specify a load-balancing policy. This + change will help prevent cases where users had hard-to-debug issues + surrounding unintuitive default load-balancing policy behavior. """ + # tracks if contact_points was set explicitly or with default values + _contact_points_explicit = None port = 9042 """ @@ -224,42 +647,27 @@ class Cluster(object): server will be automatically used. """ - protocol_version = 4 + protocol_version = ProtocolVersion.DSE_V2 """ The maximum version of the native protocol to use. - The driver will automatically downgrade version based on a negotiation with - the server, but it is most efficient to set this to the maximum supported - by your version of Cassandra. Setting this will also prevent conflicting - versions negotiated if your cluster is upgraded. + See :class:`.ProtocolVersion` for more information about versions. - Version 2 of the native protocol adds support for lightweight transactions, - batch operations, and automatic query paging. The v2 protocol is - supported by Cassandra 2.0+. + If not set in the constructor, the driver will automatically downgrade + version based on a negotiation with the server, but it is most efficient + to set this to the maximum supported by your version of Cassandra. + Setting this will also prevent conflicting versions negotiated if your + cluster is upgraded. - Version 3 of the native protocol adds support for protocol-level - client-side timestamps (see :attr:`.Session.use_client_timestamp`), - serial consistency levels for :class:`~.BatchStatement`, and an - improved connection pool. + """ - Version 4 of the native protocol adds a number of new types, server warnings, - new failure messages, and custom payloads. Details in the - `project docs `_ + allow_beta_protocol_version = False - The following table describes the native protocol versions that - are supported by each version of Cassandra: + no_compact = False - +-------------------+-------------------+ - | Cassandra Version | Protocol Versions | - +===================+===================+ - | 1.2 | 1 | - +-------------------+-------------------+ - | 2.0 | 1, 2 | - +-------------------+-------------------+ - | 2.1 | 1, 2, 3 | - +-------------------+-------------------+ - | 2.2 | 1, 2, 3, 4 | - +-------------------+-------------------+ + """ + Setting true injects a flag in all messages that makes the server accept and use "beta" protocol version. + Used for testing new protocol features incrementally before the new version is complete. 
""" compression = True @@ -311,20 +719,34 @@ def auth_provider(self, value): self._auth_provider = value - load_balancing_policy = None - """ - An instance of :class:`.policies.LoadBalancingPolicy` or - one of its subclasses. + _load_balancing_policy = None + @property + def load_balancing_policy(self): + """ + An instance of :class:`.policies.LoadBalancingPolicy` or + one of its subclasses. - .. versionchanged:: 2.6.0 + .. versionchanged:: 2.6.0 - Defaults to :class:`~.TokenAwarePolicy` (:class:`~.DCAwareRoundRobinPolicy`). - when using CPython (where the murmur3 extension is available). :class:`~.DCAwareRoundRobinPolicy` - otherwise. Default local DC will be chosen from contact points. + Defaults to :class:`~.TokenAwarePolicy` (:class:`~.DCAwareRoundRobinPolicy`). + when using CPython (where the murmur3 extension is available). :class:`~.DCAwareRoundRobinPolicy` + otherwise. Default local DC will be chosen from contact points. - **Please see** :class:`~.DCAwareRoundRobinPolicy` **for a discussion on default behavior with respect to - DC locality and remote nodes.** - """ + **Please see** :class:`~.DCAwareRoundRobinPolicy` **for a discussion on default behavior with respect to + DC locality and remote nodes.** + """ + return self._load_balancing_policy + + @load_balancing_policy.setter + def load_balancing_policy(self, lbp): + if self._config_mode == _ConfigMode.PROFILES: + raise ValueError("Cannot set Cluster.load_balancing_policy while using Configuration Profiles. Set this in a profile instead.") + self._load_balancing_policy = lbp + self._config_mode = _ConfigMode.LEGACY + + @property + def _default_load_balancing_policy(self): + return self.profile_manager.default.load_balancing_policy reconnection_policy = ExponentialReconnectionPolicy(1.0, 600.0) """ @@ -333,12 +755,22 @@ def auth_provider(self, value): a max delay of ten minutes. """ - default_retry_policy = RetryPolicy() - """ - A default :class:`.policies.RetryPolicy` instance to use for all - :class:`.Statement` objects which do not have a :attr:`~.Statement.retry_policy` - explicitly set. - """ + _default_retry_policy = RetryPolicy() + @property + def default_retry_policy(self): + """ + A default :class:`.policies.RetryPolicy` instance to use for all + :class:`.Statement` objects which do not have a :attr:`~.Statement.retry_policy` + explicitly set. + """ + return self._default_retry_policy + + @default_retry_policy.setter + def default_retry_policy(self, policy): + if self._config_mode == _ConfigMode.PROFILES: + raise ValueError("Cannot set Cluster.default_retry_policy while using Configuration Profiles. Set this in a profile instead.") + self._default_retry_policy = policy + self._config_mode = _ConfigMode.LEGACY conviction_policy_factory = SimpleConvictionPolicy """ @@ -347,6 +779,12 @@ def auth_provider(self, value): :class:`.policies.SimpleConvictionPolicy`. """ + address_translator = IdentityTranslator() + """ + :class:`.policies.AddressTranslator` instance to be used in translating server node addresses + to driver connection addresses. + """ + connect_to_remote_hosts = True """ If left as :const:`True`, hosts that are considered :attr:`~.HostDistance.REMOTE` @@ -372,14 +810,44 @@ def auth_provider(self, value): ssl_options = None """ - A optional dict which will be used as kwargs for ``ssl.wrap_socket()`` - when new sockets are created. This should be used when client encryption - is enabled in Cassandra. + Using ssl_options without ssl_context is deprecated and will be removed in the + next major release. 
+ + An optional dict which will be used as kwargs for ``ssl.SSLContext.wrap_socket`` + when new sockets are created. This should be used when client encryption is enabled + in Cassandra. + + The following documentation only applies when ssl_options is used without ssl_context. By default, a ``ca_certs`` value should be supplied (the value should be a string pointing to the location of the CA certs file), and you probably - want to specify ``ssl_version`` as ``ssl.PROTOCOL_TLSv1`` to match + want to specify ``ssl_version`` as ``ssl.PROTOCOL_TLS`` to match Cassandra's default protocol. + + .. versionchanged:: 3.3.0 + + In addition to ``wrap_socket`` kwargs, clients may also specify ``'check_hostname': True`` to verify the cert hostname + as outlined in RFC 2818 and RFC 6125. Note that this requires the certificate to be transferred, so + should almost always require the option ``'cert_reqs': ssl.CERT_REQUIRED``. Note also that this functionality was not built into + Python standard library until (2.7.9, 3.2). To enable this mechanism in earlier versions, patch ``ssl.match_hostname`` + with a custom or `back-ported function `_. + + .. versionchanged:: 3.29.0 + + ``ssl.match_hostname`` has been deprecated since Python 3.7 (and removed in Python 3.12). This functionality is now implemented + via ``ssl.SSLContext.check_hostname``. All options specified above (including ``check_hostname``) should continue to behave in a + way that is consistent with prior implementations. + """ + + ssl_context = None + """ + An optional ``ssl.SSLContext`` instance which will be used when new sockets are created. + This should be used when client encryption is enabled in Cassandra. + + ``wrap_socket`` options can be set using :attr:`~Cluster.ssl_options`. ssl_options will + be used as kwargs for ``ssl.SSLContext.wrap_socket``. + + .. versionadded:: 3.17.0 """ sockopts = None @@ -414,6 +882,7 @@ def auth_provider(self, value): * :class:`cassandra.io.eventletreactor.EventletConnection` (requires monkey-patching - see doc for details) * :class:`cassandra.io.geventreactor.GeventConnection` (requires monkey-patching - see doc for details) * :class:`cassandra.io.twistedreactor.TwistedConnection` + * EXPERIMENTAL: :class:`cassandra.io.asyncioreactor.AsyncioConnection` By default, ``AsyncoreConnection`` will be used, which uses the ``asyncore`` module in the Python standard library. @@ -422,6 +891,11 @@ def auth_provider(self, value): If ``gevent`` or ``eventlet`` monkey-patching is detected, the corresponding connection class will be used automatically. + + ``AsyncioConnection``, which uses the ``asyncio`` module in the Python + standard library, is also available, but currently experimental. Note that + it requires ``asyncio`` features that were only introduced in the 3.4 line + in 3.4.6, and in the 3.5 line in 3.5.1. """ control_connection_timeout = 2.0 @@ -439,6 +913,12 @@ def auth_provider(self, value): Setting to zero disables heartbeats. """ + idle_heartbeat_timeout = 30 + """ + Timeout, in seconds, on which the heartbeat wait for idle connection responses. + Lowering this value can help to discover bad connections earlier. + """ + schema_event_refresh_window = 2 """ Window, in seconds, within which a schema component will be refreshed after @@ -467,12 +947,41 @@ def auth_provider(self, value): Setting this to zero will execute refreshes immediately. - Setting this negative will disable node refreshes in response to push events - (refreshes will still occur in response to new nodes observed on "UP" events). 
+ Setting this negative will disable node refreshes in response to push events. See :attr:`.schema_event_refresh_window` for discussion of rationale """ + status_event_refresh_window = 2 + """ + Window, in seconds, within which the driver will start the reconnect after + receiving a status_change event. + + Setting this to zero will connect immediately. + + This is primarily used to avoid 'thundering herd' in deployments with large fanout from cluster to clients. + When nodes come up, clients attempt to reprepare prepared statements (depending on :attr:`.reprepare_on_up`), and + establish connection pools. This can cause a rush of connections and queries if not mitigated with this factor. + """ + + prepare_on_all_hosts = True + """ + Specifies whether statements should be prepared on all hosts, or just one. + + This can reasonably be disabled on long-running applications with numerous clients preparing statements on startup, + where a randomized initial condition of the load balancing policy can be expected to distribute prepares from + different clients across the cluster. + """ + + reprepare_on_up = True + """ + Specifies whether all known prepared statements should be prepared on a node when it comes up. + + May be used to avoid overwhelming a node on return, or if it is supposed that the node was only marked down due to + network. If statements are not reprepared, they are prepared on the first execution, causing + an extra roundtrip for one or more client requests. + """ + connect_timeout = 5 """ Timeout, in seconds, for creating new connections. @@ -481,6 +990,108 @@ def auth_provider(self, value): establishment, options passing, and authentication. """ + timestamp_generator = None + """ + An object, shared between all sessions created by this cluster instance, + that generates timestamps when client-side timestamp generation is enabled. + By default, each :class:`Cluster` uses a new + :class:`~.MonotonicTimestampGenerator`. + + Applications can set this value for custom timestamp behavior. See the + documentation for :meth:`Session.timestamp_generator`. + """ + + monitor_reporting_enabled = True + """ + A boolean indicating if monitor reporting, which sends gathered data to + Insights when running against DSE 6.8 and higher. + """ + + monitor_reporting_interval = 30 + """ + A boolean indicating if monitor reporting, which sends gathered data to + Insights when running against DSE 6.8 and higher. + """ + + client_id = None + """ + A UUID that uniquely identifies this Cluster object to Insights. This will + be generated automatically unless the user provides one. + """ + + application_name = '' + """ + A string identifying this application to Insights. + """ + + application_version = '' + """ + A string identifying this application's version to Insights + """ + + cloud = None + """ + A dict of the cloud configuration. Example:: + + { + # path to the secure connect bundle + 'secure_connect_bundle': '/path/to/secure-connect-dbname.zip', + + # optional config options + 'use_default_tempdir': True # use the system temp dir for the zip extraction + } + + The zip file will be temporarily extracted in the same directory to + load the configuration and certificates. + """ + + column_encryption_policy = None + """ + An instance of :class:`cassandra.policies.ColumnEncryptionPolicy` specifying encryption materials to be + used for columns in this cluster. + """ + + @property + def schema_metadata_enabled(self): + """ + Flag indicating whether internal schema metadata is updated. 
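A sketch of the startup/metadata trade-offs described in this hunk; all three keyword arguments are accepted by the constructor shown below::

    from cassandra.cluster import Cluster

    # Skip schema and token metadata for faster startup on very large clusters
    # (giving up token-aware routing and Cluster.metadata introspection), and
    # opt out of Insights monitor reporting on DSE 6.8+.
    cluster = Cluster(schema_metadata_enabled=False,
                      token_metadata_enabled=False,
                      monitor_reporting_enabled=False)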
+ + When disabled, the driver does not populate Cluster.metadata.keyspaces on connect, or on schema change events. This + can be used to speed initial connection, and reduce load on client and server during operation. Turning this off + gives away token aware request routing, and programmatic inspection of the metadata model. + """ + return self.control_connection._schema_meta_enabled + + @schema_metadata_enabled.setter + def schema_metadata_enabled(self, enabled): + self.control_connection._schema_meta_enabled = bool(enabled) + + @property + def token_metadata_enabled(self): + """ + Flag indicating whether internal token metadata is updated. + + When disabled, the driver does not query node token information on connect, or on topology change events. This + can be used to speed initial connection, and reduce load on client and server during operation. It is most useful + in large clusters using vnodes, where the token map can be expensive to compute. Turning this off + gives away token aware request routing, and programmatic inspection of the token ring. + """ + return self.control_connection._token_meta_enabled + + @token_metadata_enabled.setter + def token_metadata_enabled(self, enabled): + self.control_connection._token_meta_enabled = bool(enabled) + + endpoint_factory = None + """ + An :class:`~.connection.EndPointFactory` instance to use internally when creating + a socket connection to a node. You can ignore this unless you need a special + connection mechanism. + """ + + profile_manager = None + _config_mode = _ConfigMode.UNCOMMITTED + sessions = None control_connection = None scheduler = None @@ -490,6 +1101,8 @@ def auth_provider(self, value): _prepared_statements = None _prepared_statement_lock = None _idle_heartbeat = None + _protocol_version_explicit = False + _discount_down_events = True _user_types = None """ @@ -500,7 +1113,7 @@ def auth_provider(self, value): _listener_lock = None def __init__(self, - contact_points=["127.0.0.1"], + contact_points=_NOT_SET, port=9042, compression=True, auth_provider=None, @@ -513,47 +1126,138 @@ def __init__(self, ssl_options=None, sockopts=None, cql_version=None, - protocol_version=4, + protocol_version=_NOT_SET, executor_threads=2, max_schema_agreement_wait=10, control_connection_timeout=2.0, idle_heartbeat_interval=30, schema_event_refresh_window=2, topology_event_refresh_window=10, - connect_timeout=5): + connect_timeout=5, + schema_metadata_enabled=True, + token_metadata_enabled=True, + address_translator=None, + status_event_refresh_window=2, + prepare_on_all_hosts=True, + reprepare_on_up=True, + execution_profiles=None, + allow_beta_protocol_version=False, + timestamp_generator=None, + idle_heartbeat_timeout=30, + no_compact=False, + ssl_context=None, + endpoint_factory=None, + application_name=None, + application_version=None, + monitor_reporting_enabled=True, + monitor_reporting_interval=30, + client_id=None, + cloud=None, + column_encryption_policy=None): """ - Any of the mutable Cluster attributes may be set as keyword arguments - to the constructor. + ``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as + establishing connection pools or refreshing metadata. + + Any of the mutable Cluster attributes may be set as keyword arguments to the constructor. 
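A sketch of the ``cloud`` configuration path handled just below (bundle path and credentials are placeholders; ``contact_points``, ``ssl_context``, and ``ssl_options`` must not be combined with it)::

    from cassandra.cluster import Cluster
    from cassandra.auth import PlainTextAuthProvider

    cloud_config = {'secure_connect_bundle': '/path/to/secure-connect-dbname.zip'}
    auth_provider = PlainTextAuthProvider('client_id', 'client_secret')

    cluster = Cluster(cloud=cloud_config, auth_provider=auth_provider)
    session = cluster.connect()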
""" + if connection_class is not None: + self.connection_class = connection_class + + if cloud is not None: + self.cloud = cloud + if contact_points is not _NOT_SET or endpoint_factory or ssl_context or ssl_options: + raise ValueError("contact_points, endpoint_factory, ssl_context, and ssl_options " + "cannot be specified with a cloud configuration") + + uses_twisted = TwistedConnection and issubclass(self.connection_class, TwistedConnection) + uses_eventlet = EventletConnection and issubclass(self.connection_class, EventletConnection) + cloud_config = dscloud.get_cloud_config(cloud, create_pyopenssl_context=uses_twisted or uses_eventlet) + + ssl_context = cloud_config.ssl_context + ssl_options = {'check_hostname': True} + if (auth_provider is None and cloud_config.username + and cloud_config.password): + auth_provider = PlainTextAuthProvider(cloud_config.username, cloud_config.password) + + endpoint_factory = SniEndPointFactory(cloud_config.sni_host, cloud_config.sni_port) + contact_points = [ + endpoint_factory.create_from_sni(host_id) + for host_id in cloud_config.host_ids + ] + if contact_points is not None: - if isinstance(contact_points, six.string_types): + if contact_points is _NOT_SET: + self._contact_points_explicit = False + contact_points = ['127.0.0.1'] + else: + self._contact_points_explicit = True + + if isinstance(contact_points, str): raise TypeError("contact_points should not be a string, it should be a sequence (e.g. list) of strings") + if None in contact_points: + raise ValueError("contact_points should not contain None (it can resolve to localhost)") self.contact_points = contact_points self.port = port + + if column_encryption_policy is not None: + self.column_encryption_policy = column_encryption_policy + + self.endpoint_factory = endpoint_factory or DefaultEndPointFactory(port=self.port) + self.endpoint_factory.configure(self) + + raw_contact_points = [] + for cp in [cp for cp in self.contact_points if not isinstance(cp, EndPoint)]: + raw_contact_points.append(cp if isinstance(cp, tuple) else (cp, port)) + + self.endpoints_resolved = [cp for cp in self.contact_points if isinstance(cp, EndPoint)] + self._endpoint_map_for_insights = {repr(ep): '{ip}:{port}'.format(ip=ep.address, port=ep.port) + for ep in self.endpoints_resolved} + + strs_resolved_map = _resolve_contact_points_to_string_map(raw_contact_points) + self.endpoints_resolved.extend(list(chain( + *[ + [DefaultEndPoint(ip, port) for ip, port in xs if ip is not None] + for xs in strs_resolved_map.values() if xs is not None + ] + ))) + + self._endpoint_map_for_insights.update( + {key: ['{ip}:{port}'.format(ip=ip, port=port) for ip, port in value] + for key, value in strs_resolved_map.items() if value is not None} + ) + + if contact_points and (not self.endpoints_resolved): + # only want to raise here if the user specified CPs but resolution failed + raise UnresolvableContactPoints(self._endpoint_map_for_insights) + self.compression = compression - self.protocol_version = protocol_version + + if protocol_version is not _NOT_SET: + self.protocol_version = protocol_version + self._protocol_version_explicit = True + self.allow_beta_protocol_version = allow_beta_protocol_version + + self.no_compact = no_compact + self.auth_provider = auth_provider if load_balancing_policy is not None: if isinstance(load_balancing_policy, type): raise TypeError("load_balancing_policy should not be a class, it should be an instance of that class") - self.load_balancing_policy = load_balancing_policy else: - self.load_balancing_policy 
= default_lbp_factory() + self._load_balancing_policy = default_lbp_factory() # set internal attribute to avoid committing to legacy config mode if reconnection_policy is not None: if isinstance(reconnection_policy, type): raise TypeError("reconnection_policy should not be a class, it should be an instance of that class") - self.reconnection_policy = reconnection_policy if default_retry_policy is not None: if isinstance(default_retry_policy, type): raise TypeError("default_retry_policy should not be a class, it should be an instance of that class") - self.default_retry_policy = default_retry_policy if conviction_policy_factory is not None: @@ -561,19 +1265,97 @@ def __init__(self, raise ValueError("conviction_policy_factory must be callable") self.conviction_policy_factory = conviction_policy_factory - if connection_class is not None: - self.connection_class = connection_class + if address_translator is not None: + if isinstance(address_translator, type): + raise TypeError("address_translator should not be a class, it should be an instance of that class") + self.address_translator = address_translator + + if timestamp_generator is not None: + if not callable(timestamp_generator): + raise ValueError("timestamp_generator must be callable") + self.timestamp_generator = timestamp_generator + else: + self.timestamp_generator = MonotonicTimestampGenerator() + + self.profile_manager = ProfileManager() + self.profile_manager.profiles[EXEC_PROFILE_DEFAULT] = ExecutionProfile( + self.load_balancing_policy, + self.default_retry_policy, + request_timeout=Session._default_timeout, + row_factory=Session._row_factory + ) + + # legacy mode if either of these is not default + if load_balancing_policy or default_retry_policy: + if execution_profiles: + raise ValueError("Clusters constructed with execution_profiles should not specify legacy parameters " + "load_balancing_policy or default_retry_policy. Configure this in a profile instead.") + + self._config_mode = _ConfigMode.LEGACY + warn("Legacy execution parameters will be removed in 4.0. Consider using " + "execution profiles.", DeprecationWarning) + + else: + profiles = self.profile_manager.profiles + if execution_profiles: + profiles.update(execution_profiles) + self._config_mode = _ConfigMode.PROFILES + + lbp = DefaultLoadBalancingPolicy(self.profile_manager.default.load_balancing_policy) + profiles.setdefault(EXEC_PROFILE_GRAPH_DEFAULT, GraphExecutionProfile(load_balancing_policy=lbp)) + profiles.setdefault(EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT, + GraphExecutionProfile(load_balancing_policy=lbp, request_timeout=60. * 3.)) + profiles.setdefault(EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT, + GraphAnalyticsExecutionProfile(load_balancing_policy=lbp)) + + if self._contact_points_explicit and not self.cloud: # avoid this warning for cloud users. + if self._config_mode is _ConfigMode.PROFILES: + default_lbp_profiles = self.profile_manager._profiles_without_explicit_lbps() + if default_lbp_profiles: + log.warning( + 'Cluster.__init__ called with contact_points ' + 'specified, but load-balancing policies are not ' + 'specified in some ExecutionProfiles. In the next ' + 'major version, this will raise an error; please ' + 'specify a load-balancing policy. ' + '(contact_points = {cp}, ' + 'EPs without explicit LBPs = {eps})' + ''.format(cp=contact_points, eps=default_lbp_profiles)) + else: + if load_balancing_policy is None: + log.warning( + 'Cluster.__init__ called with contact_points ' + 'specified, but no load_balancing_policy. 
In the next ' + 'major version, this will raise an error; please ' + 'specify a load-balancing policy. ' + '(contact_points = {cp}, lbp = {lbp})' + ''.format(cp=contact_points, lbp=load_balancing_policy)) self.metrics_enabled = metrics_enabled + + if ssl_options and not ssl_context: + warn('Using ssl_options without ssl_context is ' + 'deprecated and will result in an error in ' + 'the next major release. Please use ssl_context ' + 'to prepare for that release.', + DeprecationWarning) + self.ssl_options = ssl_options + self.ssl_context = ssl_context self.sockopts = sockopts self.cql_version = cql_version self.max_schema_agreement_wait = max_schema_agreement_wait self.control_connection_timeout = control_connection_timeout self.idle_heartbeat_interval = idle_heartbeat_interval + self.idle_heartbeat_timeout = idle_heartbeat_timeout self.schema_event_refresh_window = schema_event_refresh_window self.topology_event_refresh_window = topology_event_refresh_window + self.status_event_refresh_window = status_event_refresh_window self.connect_timeout = connect_timeout + self.prepare_on_all_hosts = prepare_on_all_hosts + self.reprepare_on_up = reprepare_on_up + self.monitor_reporting_enabled = monitor_reporting_enabled + self.monitor_reporting_interval = monitor_reporting_interval self._listeners = set() self._listener_lock = Lock() @@ -608,7 +1390,7 @@ def __init__(self, HostDistance.REMOTE: DEFAULT_MAX_CONNECTIONS_PER_REMOTE_HOST } - self.executor = ThreadPoolExecutor(max_workers=executor_threads) + self.executor = self._create_thread_pool_executor(max_workers=executor_threads) self.scheduler = _Scheduler(self.executor) self._lock = RLock() @@ -619,7 +1401,52 @@ def __init__(self, self.control_connection = ControlConnection( self, self.control_connection_timeout, - self.schema_event_refresh_window, self.topology_event_refresh_window) + self.schema_event_refresh_window, self.topology_event_refresh_window, + self.status_event_refresh_window, + schema_metadata_enabled, token_metadata_enabled) + + if client_id is None: + self.client_id = uuid.uuid4() + if application_name is not None: + self.application_name = application_name + if application_version is not None: + self.application_version = application_version + + def _create_thread_pool_executor(self, **kwargs): + """ + Create a ThreadPoolExecutor for the cluster. In most cases, the built-in + `concurrent.futures.ThreadPoolExecutor` is used. + + Python 3.7+ and Eventlet cause the `concurrent.futures.ThreadPoolExecutor` + to hang indefinitely. In that case, the user needs to have the `futurist` + package so we can use the `futurist.GreenThreadPoolExecutor` class instead. + + :param kwargs: All keyword args are passed to the ThreadPoolExecutor constructor. + :return: A ThreadPoolExecutor instance. + """ + tpe_class = ThreadPoolExecutor + if sys.version_info[0] >= 3 and sys.version_info[1] >= 7: + try: + from cassandra.io.eventletreactor import EventletConnection + is_eventlet = issubclass(self.connection_class, EventletConnection) + except: + # Eventlet is not available or can't be detected + return tpe_class(**kwargs) + + if is_eventlet: + try: + from futurist import GreenThreadPoolExecutor + tpe_class = GreenThreadPoolExecutor + except ImportError: + # futurist is not available + raise ImportError( + ("Python 3.7+ and Eventlet cause the `concurrent.futures.ThreadPoolExecutor` " + "to hang indefinitely. 
If you want to use the Eventlet reactor, you " + "need to install the `futurist` package to allow the driver to use " + "the GreenThreadPoolExecutor. See https://github.com/eventlet/eventlet/issues/508 " + "for more details.")) + + return tpe_class(**kwargs) def register_user_type(self, keyspace, user_type, klass): """ @@ -636,7 +1463,7 @@ def register_user_type(self, keyspace, user_type, klass): for. `klass` should be a class with attributes whose names match the - fields of the user-defined type. The constructor must accepts kwargs + fields of the user-defined type. The constructor must accept kwargs for each of the fields in the UDT. This method should only be called after the type has been created @@ -666,7 +1493,7 @@ def __init__(self, street, zipcode): # results will include Address instances results = session.execute("SELECT * FROM users") row = results[0] - print row.id, row.location.street, row.location.zipcode + print(row.id, row.location.street, row.location.zipcode) """ if self.protocol_version < 3: @@ -675,10 +1502,58 @@ def __init__(self, street, zipcode): "be returned when reading type %s.%s.", self.protocol_version, keyspace, user_type) self._user_types[keyspace][user_type] = klass - for session in self.sessions: + for session in tuple(self.sessions): session.user_type_registered(keyspace, user_type, klass) UserType.evict_udt_class(keyspace, user_type) + def add_execution_profile(self, name, profile, pool_wait_timeout=5): + """ + Adds an :class:`.ExecutionProfile` to the cluster. This makes it available for use by ``name`` in :meth:`.Session.execute` + and :meth:`.Session.execute_async`. This method will raise if the profile already exists. + + Normally profiles will be injected at cluster initialization via ``Cluster(execution_profiles)``. This method + provides a way of adding them dynamically. + + Adding a new profile updates the connection pools according to the specified ``load_balancing_policy``. By default, + this method will wait up to five seconds for the pool creation to complete, so the profile can be used immediately + upon return. This behavior can be controlled using ``pool_wait_timeout`` (see + `concurrent.futures.wait `_ + for timeout semantics). + """ + if not isinstance(profile, ExecutionProfile): + raise TypeError("profile must be an instance of ExecutionProfile") + if self._config_mode == _ConfigMode.LEGACY: + raise ValueError("Cannot add execution profiles when legacy parameters are set explicitly.") + if name in self.profile_manager.profiles: + raise ValueError("Profile {} already exists".format(name)) + contact_points_but_no_lbp = ( + self._contact_points_explicit and not + profile._load_balancing_policy_explicit) + if contact_points_but_no_lbp: + log.warning( + 'Tried to add an ExecutionProfile with name {name}. ' + '{self} was explicitly configured with contact_points, but ' + '{ep} was not explicitly configured with a ' + 'load_balancing_policy. In the next major version, trying to ' + 'add an ExecutionProfile without an explicitly configured LBP ' + 'to a cluster with explicitly configured contact_points will ' + 'raise an exception; please specify a load-balancing policy ' + 'in the ExecutionProfile.' 
+ ''.format(name=_execution_profile_to_string(name), self=self, ep=profile)) + + self.profile_manager.profiles[name] = profile + profile.load_balancing_policy.populate(self, self.metadata.all_hosts()) + # on_up after populate allows things like DCA LBP to choose default local dc + for host in filter(lambda h: h.is_up, self.metadata.all_hosts()): + profile.load_balancing_policy.on_up(host) + futures = set() + for session in tuple(self.sessions): + self._set_default_dbaas_consistency(session) + futures.update(session.update_created_pools()) + _, not_done = wait_futures(futures, pool_wait_timeout) + if not_done: + raise OperationTimedOut("Failed to create all new connection pools in the %ss timeout.") + def get_min_requests_per_connection(self, host_distance): return self._min_requests_per_connection[host_distance] @@ -694,6 +1569,10 @@ def set_min_requests_per_connection(self, host_distance, min_requests): raise UnsupportedOperation( "Cluster.set_min_requests_per_connection() only has an effect " "when using protocol_version 1 or 2.") + if min_requests < 0 or min_requests > 126 or \ + min_requests >= self._max_requests_per_connection[host_distance]: + raise ValueError("min_requests must be 0-126 and less than the max_requests for this host_distance (%d)" % + (self._min_requests_per_connection[host_distance],)) self._min_requests_per_connection[host_distance] = min_requests def get_max_requests_per_connection(self, host_distance): @@ -711,6 +1590,10 @@ def set_max_requests_per_connection(self, host_distance, max_requests): raise UnsupportedOperation( "Cluster.set_max_requests_per_connection() only has an effect " "when using protocol_version 1 or 2.") + if max_requests < 1 or max_requests > 127 or \ + max_requests <= self._min_requests_per_connection[host_distance]: + raise ValueError("max_requests must be 1-127 and greater than the min_requests for this host_distance (%d)" % + (self._min_requests_per_connection[host_distance],)) self._max_requests_per_connection[host_distance] = max_requests def get_core_connections_per_host(self, host_distance): @@ -739,7 +1622,7 @@ def set_core_connections_per_host(self, host_distance, core_connections): If :attr:`~.Cluster.protocol_version` is set to 3 or higher, this is not supported (there is always one connection per host, unless the host is remote and :attr:`connect_to_remote_hosts` is :const:`False`) - and using this will result in an :exc:`~.UnsupporteOperation`. + and using this will result in an :exc:`~.UnsupportedOperation`. """ if self.protocol_version >= 3: raise UnsupportedOperation( @@ -772,7 +1655,7 @@ def set_max_connections_per_host(self, host_distance, max_connections): If :attr:`~.Cluster.protocol_version` is set to 3 or higher, this is not supported (there is always one connection per host, unless the host is remote and :attr:`connect_to_remote_hosts` is :const:`False`) - and using this will result in an :exc:`~.UnsupporteOperation`. + and using this will result in an :exc:`~.UnsupportedOperation`. """ if self.protocol_version >= 3: raise UnsupportedOperation( @@ -780,68 +1663,90 @@ def set_max_connections_per_host(self, host_distance, max_connections): "when using protocol_version 1 or 2.") self._max_connections_per_host[host_distance] = max_connections - def connection_factory(self, address, *args, **kwargs): + def connection_factory(self, endpoint, *args, **kwargs): """ Called to create a new connection with proper configuration. Intended for internal use only. 
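For orientation, a rough sketch of the dynamic-profile path added above; the profile names, query and table are made up, and only keyword arguments that appear elsewhere in this diff (``request_timeout``, ``execution_profiles``, the ``execution_profile`` argument to ``execute``) are assumed::

    from cassandra.cluster import Cluster, ExecutionProfile

    cluster = Cluster(execution_profiles={'oltp': ExecutionProfile(request_timeout=2.0)})
    session = cluster.connect()

    # add_execution_profile() waits (up to pool_wait_timeout seconds) for the new
    # connection pools, so the profile is usable as soon as the call returns.
    cluster.add_execution_profile('analytics', ExecutionProfile(request_timeout=120.0))
    session.execute("SELECT count(*) FROM ks.events", execution_profile='analytics')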
""" - kwargs = self._make_connection_kwargs(address, kwargs) - return self.connection_class.factory(address, self.connect_timeout, *args, **kwargs) + kwargs = self._make_connection_kwargs(endpoint, kwargs) + return self.connection_class.factory(endpoint, self.connect_timeout, *args, **kwargs) def _make_connection_factory(self, host, *args, **kwargs): - kwargs = self._make_connection_kwargs(host.address, kwargs) - return partial(self.connection_class.factory, host.address, self.connect_timeout, *args, **kwargs) + kwargs = self._make_connection_kwargs(host.endpoint, kwargs) + return partial(self.connection_class.factory, host.endpoint, self.connect_timeout, *args, **kwargs) - def _make_connection_kwargs(self, address, kwargs_dict): + def _make_connection_kwargs(self, endpoint, kwargs_dict): if self._auth_provider_callable: - kwargs_dict.setdefault('authenticator', self._auth_provider_callable(address)) + kwargs_dict.setdefault('authenticator', self._auth_provider_callable(endpoint.address)) kwargs_dict.setdefault('port', self.port) kwargs_dict.setdefault('compression', self.compression) kwargs_dict.setdefault('sockopts', self.sockopts) kwargs_dict.setdefault('ssl_options', self.ssl_options) + kwargs_dict.setdefault('ssl_context', self.ssl_context) kwargs_dict.setdefault('cql_version', self.cql_version) kwargs_dict.setdefault('protocol_version', self.protocol_version) kwargs_dict.setdefault('user_type_map', self._user_types) + kwargs_dict.setdefault('allow_beta_protocol_version', self.allow_beta_protocol_version) + kwargs_dict.setdefault('no_compact', self.no_compact) return kwargs_dict - def protocol_downgrade(self, host_addr, previous_version): - new_version = previous_version - 1 - if new_version < self.protocol_version: - if new_version >= MIN_SUPPORTED_VERSION: - log.warning("Downgrading core protocol version from %d to %d for %s", self.protocol_version, new_version, host_addr) - self.protocol_version = new_version - else: - raise Exception("Cannot downgrade protocol version (%d) below minimum supported version: %d" % (new_version, MIN_SUPPORTED_VERSION)) + def protocol_downgrade(self, host_endpoint, previous_version): + if self._protocol_version_explicit: + raise DriverException("ProtocolError returned from server while using explicitly set client protocol_version %d" % (previous_version,)) + new_version = ProtocolVersion.get_lower_supported(previous_version) + if new_version < ProtocolVersion.MIN_SUPPORTED: + raise DriverException( + "Cannot downgrade protocol version below minimum supported version: %d" % (ProtocolVersion.MIN_SUPPORTED,)) + + log.warning("Downgrading core protocol version from %d to %d for %s. " + "To avoid this, it is best practice to explicitly set Cluster(protocol_version) to the version supported by your cluster. " + "https://docs.datastax.com/en/developer/python-driver/latest/api/cassandra/cluster.html#cassandra.cluster.Cluster.protocol_version", self.protocol_version, new_version, host_endpoint) + self.protocol_version = new_version - def connect(self, keyspace=None): + def connect(self, keyspace=None, wait_for_all_pools=False): """ - Creates and returns a new :class:`~.Session` object. If `keyspace` - is specified, that keyspace will be the default keyspace for + Creates and returns a new :class:`~.Session` object. + + If `keyspace` is specified, that keyspace will be the default keyspace for operations on the ``Session``. + + `wait_for_all_pools` specifies whether this call should wait for all connection pools to be + established or attempted. 
Default is `False`, which means it will return when the first + successful connection is established. Remaining pools are added asynchronously. """ with self._lock: if self.is_shutdown: - raise Exception("Cluster is already shut down") + raise DriverException("Cluster is already shut down") if not self._is_setup: log.debug("Connecting to cluster, contact points: %s; protocol version: %s", self.contact_points, self.protocol_version) self.connection_class.initialize_reactor() - atexit.register(partial(_shutdown_cluster, self)) - for address in self.contact_points: - host, new = self.add_host(address, signal=False) + _register_cluster_shutdown(self) + for endpoint in self.endpoints_resolved: + host, new = self.add_host(endpoint, signal=False) if new: host.set_up() for listener in self.listeners: listener.on_add(host) - self.load_balancing_policy.populate( + self.profile_manager.populate( weakref.proxy(self), self.metadata.all_hosts()) + self.load_balancing_policy.populate( + weakref.proxy(self), self.metadata.all_hosts() + ) try: self.control_connection.connect() + + # we set all contact points up for connecting, but we won't infer state after this + for endpoint in self.endpoints_resolved: + h = self.metadata.get_host(endpoint) + if h and self.profile_manager.distance(h) == HostDistance.IGNORED: + h.is_up = None + log.debug("Control connection created") except Exception: log.exception("Control connection failed to connect, " @@ -849,20 +1754,34 @@ def connect(self, keyspace=None): self.shutdown() raise - self.load_balancing_policy.check_supported() + self.profile_manager.check_supported() # todo: rename this method if self.idle_heartbeat_interval: - self._idle_heartbeat = ConnectionHeartbeat(self.idle_heartbeat_interval, self.get_connection_holders) + self._idle_heartbeat = ConnectionHeartbeat( + self.idle_heartbeat_interval, + self.get_connection_holders, + timeout=self.idle_heartbeat_timeout + ) self._is_setup = True - session = self._new_session() - if keyspace: - session.set_keyspace(keyspace) + session = self._new_session(keyspace) + if wait_for_all_pools: + wait_futures(session._initial_connect_futures) + + self._set_default_dbaas_consistency(session) + return session + def _set_default_dbaas_consistency(self, session): + if session.cluster.metadata.dbaas: + for profile in self.profile_manager.profiles.values(): + if not profile._consistency_level_explicit: + profile.consistency_level = ConsistencyLevel.LOCAL_QUORUM + session._default_consistency_level = ConsistencyLevel.LOCAL_QUORUM + def get_connection_holders(self): holders = [] - for s in self.sessions: + for s in tuple(self.sessions): holders.extend(s.get_pools()) holders.append(self.control_connection) return holders @@ -888,26 +1807,34 @@ def shutdown(self): self.control_connection.shutdown() - for session in self.sessions: + for session in tuple(self.sessions): session.shutdown() self.executor.shutdown() - def _new_session(self): - session = Session(self, self.metadata.all_hosts()) + _discard_cluster_shutdown(self) + + def __enter__(self): + return self + + def __exit__(self, *args): + self.shutdown() + + def _new_session(self, keyspace): + session = Session(self, self.metadata.all_hosts(), keyspace) self._session_register_user_types(session) self.sessions.add(session) return session def _session_register_user_types(self, session): - for keyspace, type_map in six.iteritems(self._user_types): - for udt_name, klass in six.iteritems(type_map): + for keyspace, type_map in self._user_types.items(): + for udt_name, klass in 
type_map.items(): session.user_type_registered(keyspace, udt_name, klass) def _cleanup_failed_on_up_handling(self, host): - self.load_balancing_policy.on_down(host) + self.profile_manager.on_down(host) self.control_connection.on_down(host) - for session in self.sessions: + for session in tuple(self.sessions): session.remove_pool(host) self._start_reconnector(host, is_host_addition=False) @@ -946,7 +1873,7 @@ def _on_up_future_completed(self, host, futures, results, lock, finished_future) host._currently_handling_node_up = False # see if there are any pools to add or remove now that the host is marked up - for session in self.sessions: + for session in tuple(self.sessions): session.update_created_pools() def on_up(self, host): @@ -979,14 +1906,15 @@ def on_up(self, host): log.debug("Now that host %s is up, cancelling the reconnection handler", host) reconnector.cancel() - self._prepare_all_queries(host) - log.debug("Done preparing all queries for host %s, ", host) + if self.profile_manager.distance(host) != HostDistance.IGNORED: + self._prepare_all_queries(host) + log.debug("Done preparing all queries for host %s, ", host) - for session in self.sessions: + for session in tuple(self.sessions): session.remove_pool(host) - log.debug("Signalling to load balancing policy that host %s is up", host) - self.load_balancing_policy.on_up(host) + log.debug("Signalling to load balancing policies that host %s is up", host) + self.profile_manager.on_up(host) log.debug("Signalling to control connection that host %s is up", host) self.control_connection.on_up(host) @@ -995,7 +1923,7 @@ def on_up(self, host): futures_lock = Lock() futures_results = [] callback = partial(self._on_up_future_completed, host, futures, futures_results, futures_lock) - for session in self.sessions: + for session in tuple(self.sessions): future = session.add_or_renew_pool(host, is_host_addition=False) if future is not None: have_future = True @@ -1014,13 +1942,14 @@ def on_up(self, host): else: if not have_future: with host.lock: + host.set_up() host._currently_handling_node_up = False # for testing purposes return futures def _start_reconnector(self, host, is_host_addition): - if self.load_balancing_policy.distance(host) == HostDistance.IGNORED: + if self.profile_manager.distance(host) == HostDistance.IGNORED: return schedule = self.reconnection_policy.new_schedule() @@ -1052,16 +1981,29 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False): return with host.lock: - if (not host.is_up and not expect_host_to_be_down) or host.is_currently_reconnecting(): - return + was_up = host.is_up + + # ignore down signals if we have open pools to the host + # this is to avoid closing pools when a control connection host became isolated + if self._discount_down_events and self.profile_manager.distance(host) != HostDistance.IGNORED: + connected = False + for session in tuple(self.sessions): + pool_states = session.get_pool_state() + pool_state = pool_states.get(host) + if pool_state: + connected |= pool_state['open_count'] > 0 + if connected: + return host.set_down() + if (not was_up and not expect_host_to_be_down) or host.is_currently_reconnecting(): + return log.warning("Host %s has been marked down", host) - self.load_balancing_policy.on_down(host) + self.profile_manager.on_down(host) self.control_connection.on_down(host) - for session in self.sessions: + for session in tuple(self.sessions): session.on_down(host) for listener in self.listeners: @@ -1075,18 +2017,18 @@ def on_add(self, host, refresh_nodes=True): 
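The ``on_up``/``on_down``/``on_add``/``on_remove`` handlers in this region fan the host events out to registered listeners. A minimal listener sketch, assuming the driver's ``HostStateListener`` base class and ``Cluster.register_listener()``, neither of which appears in this hunk::

    import logging
    from cassandra.policies import HostStateListener

    logger = logging.getLogger(__name__)

    class LoggingListener(HostStateListener):
        # Each hook receives the Host that changed state.
        def on_up(self, host):
            logger.info("host up: %s", host)

        def on_down(self, host):
            logger.warning("host down: %s", host)

        def on_add(self, host):
            logger.info("host added: %s", host)

        def on_remove(self, host):
            logger.info("host removed: %s", host)

    cluster.register_listener(LoggingListener())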
log.debug("Handling new host %r and notifying listeners", host) - distance = self.load_balancing_policy.distance(host) + distance = self.profile_manager.distance(host) if distance != HostDistance.IGNORED: self._prepare_all_queries(host) log.debug("Done preparing queries for new host %r", host) - self.load_balancing_policy.on_add(host) + self.profile_manager.on_add(host) self.control_connection.on_add(host, refresh_nodes) if distance == HostDistance.IGNORED: log.debug("Not adding connection pool for new host %r because the " "load balancing policy has marked it as IGNORED", host) - self._finalize_add(host) + self._finalize_add(host, set_up=False) return futures_lock = Lock() @@ -1118,7 +2060,7 @@ def future_completed(future): self._finalize_add(host) have_future = False - for session in self.sessions: + for session in tuple(self.sessions): future = session.add_or_renew_pool(host, is_host_addition=True) if future is not None: have_future = True @@ -1128,14 +2070,15 @@ def future_completed(future): if not have_future: self._finalize_add(host) - def _finalize_add(self, host): - # mark the host as up and notify all listeners - host.set_up() + def _finalize_add(self, host, set_up=True): + if set_up: + host.set_up() + for listener in self.listeners: listener.on_add(host) # see if there are any pools to add or remove now that the host is marked up - for session in self.sessions: + for session in tuple(self.sessions): session.update_created_pools() def on_remove(self, host): @@ -1144,20 +2087,24 @@ def on_remove(self, host): log.debug("Removing host %s", host) host.set_down() - self.load_balancing_policy.on_remove(host) - for session in self.sessions: + self.profile_manager.on_remove(host) + for session in tuple(self.sessions): session.on_remove(host) for listener in self.listeners: listener.on_remove(host) self.control_connection.on_remove(host) + reconnection_handler = host.get_and_set_reconnection_handler(None) + if reconnection_handler: + reconnection_handler.cancel() + def signal_connection_failure(self, host, connection_exc, is_host_addition, expect_host_to_be_down=False): is_down = host.signal_connection_failure(connection_exc) if is_down: self.on_down(host, is_host_addition, expect_host_to_be_down) return is_down - def add_host(self, address, datacenter=None, rack=None, signal=True, refresh_nodes=True): + def add_host(self, endpoint, datacenter=None, rack=None, signal=True, refresh_nodes=True): """ Called when adding initial contact points and when the control connection subsequently discovers a new node. @@ -1165,7 +2112,7 @@ def add_host(self, address, datacenter=None, rack=None, signal=True, refresh_nod the metadata. Intended for internal use only. """ - host, new = self.metadata.add_or_return_host(Host(address, self.conviction_policy_factory, datacenter, rack)) + host, new = self.metadata.add_or_return_host(Host(endpoint, self.conviction_policy_factory, datacenter, rack)) if new and signal: log.info("New Cassandra host %r discovered", host) self.on_add(host, refresh_nodes) @@ -1205,8 +2152,8 @@ def _ensure_core_connections(self): If any host has fewer than the configured number of core connections open, attempt to open connections until that number is met. 
""" - for session in self.sessions: - for pool in session._pools.values(): + for session in tuple(self.sessions): + for pool in tuple(session._pools.values()): pool.ensure_core_connections() @staticmethod @@ -1231,6 +2178,14 @@ def _target_type_from_refresh_args(keyspace, table, usertype, function, aggregat return SchemaTargetType.KEYSPACE return None + def get_control_connection_host(self): + """ + Returns the control connection host metadata. + """ + connection = self.control_connection._connection + endpoint = connection.endpoint if connection else None + return self.metadata.get_host(endpoint) if endpoint else None + def refresh_schema_metadata(self, max_schema_agreement_wait=None): """ Synchronously refresh all schema metadata. @@ -1244,8 +2199,8 @@ def refresh_schema_metadata(self, max_schema_agreement_wait=None): An Exception is raised if schema refresh fails for any reason. """ - if not self.control_connection.refresh_schema(schema_agreement_wait=max_schema_agreement_wait): - raise Exception("Schema metadata was not refreshed. See log for details.") + if not self.control_connection.refresh_schema(schema_agreement_wait=max_schema_agreement_wait, force=True): + raise DriverException("Schema metadata was not refreshed. See log for details.") def refresh_keyspace_metadata(self, keyspace, max_schema_agreement_wait=None): """ @@ -1255,8 +2210,8 @@ def refresh_keyspace_metadata(self, keyspace, max_schema_agreement_wait=None): See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ if not self.control_connection.refresh_schema(target_type=SchemaTargetType.KEYSPACE, keyspace=keyspace, - schema_agreement_wait=max_schema_agreement_wait): - raise Exception("Keyspace metadata was not refreshed. See log for details.") + schema_agreement_wait=max_schema_agreement_wait, force=True): + raise DriverException("Keyspace metadata was not refreshed. See log for details.") def refresh_table_metadata(self, keyspace, table, max_schema_agreement_wait=None): """ @@ -1265,8 +2220,9 @@ def refresh_table_metadata(self, keyspace, table, max_schema_agreement_wait=None See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ - if not self.control_connection.refresh_schema(target_type=SchemaTargetType.TABLE, keyspace=keyspace, table=table, schema_agreement_wait=max_schema_agreement_wait): - raise Exception("Table metadata was not refreshed. See log for details.") + if not self.control_connection.refresh_schema(target_type=SchemaTargetType.TABLE, keyspace=keyspace, table=table, + schema_agreement_wait=max_schema_agreement_wait, force=True): + raise DriverException("Table metadata was not refreshed. See log for details.") def refresh_materialized_view_metadata(self, keyspace, view, max_schema_agreement_wait=None): """ @@ -1274,8 +2230,9 @@ def refresh_materialized_view_metadata(self, keyspace, view, max_schema_agreemen See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ - if not self.control_connection.refresh_schema(target_type=SchemaTargetType.TABLE, keyspace=keyspace, table=view, schema_agreement_wait=max_schema_agreement_wait): - raise Exception("View metadata was not refreshed. See log for details.") + if not self.control_connection.refresh_schema(target_type=SchemaTargetType.TABLE, keyspace=keyspace, table=view, + schema_agreement_wait=max_schema_agreement_wait, force=True): + raise DriverException("View metadata was not refreshed. 
See log for details.") def refresh_user_type_metadata(self, keyspace, user_type, max_schema_agreement_wait=None): """ @@ -1283,8 +2240,9 @@ def refresh_user_type_metadata(self, keyspace, user_type, max_schema_agreement_w See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ - if not self.control_connection.refresh_schema(target_type=SchemaTargetType.TYPE, keyspace=keyspace, type=user_type, schema_agreement_wait=max_schema_agreement_wait): - raise Exception("User Type metadata was not refreshed. See log for details.") + if not self.control_connection.refresh_schema(target_type=SchemaTargetType.TYPE, keyspace=keyspace, type=user_type, + schema_agreement_wait=max_schema_agreement_wait, force=True): + raise DriverException("User Type metadata was not refreshed. See log for details.") def refresh_user_function_metadata(self, keyspace, function, max_schema_agreement_wait=None): """ @@ -1294,8 +2252,9 @@ def refresh_user_function_metadata(self, keyspace, function, max_schema_agreemen See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ - if not self.control_connection.refresh_schema(target_type=SchemaTargetType.FUNCTION, keyspace=keyspace, function=function, schema_agreement_wait=max_schema_agreement_wait): - raise Exception("User Function metadata was not refreshed. See log for details.") + if not self.control_connection.refresh_schema(target_type=SchemaTargetType.FUNCTION, keyspace=keyspace, function=function, + schema_agreement_wait=max_schema_agreement_wait, force=True): + raise DriverException("User Function metadata was not refreshed. See log for details.") def refresh_user_aggregate_metadata(self, keyspace, aggregate, max_schema_agreement_wait=None): """ @@ -1305,20 +2264,25 @@ def refresh_user_aggregate_metadata(self, keyspace, aggregate, max_schema_agreem See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ - if not self.control_connection.refresh_schema(target_type=SchemaTargetType.AGGREGATE, keyspace=keyspace, aggregate=aggregate, schema_agreement_wait=max_schema_agreement_wait): - raise Exception("User Aggregate metadata was not refreshed. See log for details.") + if not self.control_connection.refresh_schema(target_type=SchemaTargetType.AGGREGATE, keyspace=keyspace, aggregate=aggregate, + schema_agreement_wait=max_schema_agreement_wait, force=True): + raise DriverException("User Aggregate metadata was not refreshed. See log for details.") - def refresh_nodes(self): + def refresh_nodes(self, force_token_rebuild=False): """ Synchronously refresh the node list and token metadata + `force_token_rebuild` can be used to rebuild the token map metadata, even if no new nodes are discovered. + An Exception is raised if node refresh fails for any reason. """ - if not self.control_connection.refresh_node_list_and_token_map(): - raise Exception("Node list was not refreshed. See log for details.") + if not self.control_connection.refresh_node_list_and_token_map(force_token_rebuild): + raise DriverException("Node list was not refreshed. See log for details.") def set_meta_refresh_enabled(self, enabled): """ + *Deprecated:* set :attr:`~.Cluster.schema_metadata_enabled` :attr:`~.Cluster.token_metadata_enabled` instead + Sets a flag to enable (True) or disable (False) all metadata refresh queries. This applies to both schema and node topology. 
@@ -1327,41 +2291,50 @@ def set_meta_refresh_enabled(self, enabled): Meta refresh must be enabled for the driver to become aware of any cluster topology changes or schema updates. """ - self.control_connection.set_meta_refresh_enabled(bool(enabled)) + warn("Cluster.set_meta_refresh_enabled is deprecated and will be removed in 4.0. Set " + "Cluster.schema_metadata_enabled and Cluster.token_metadata_enabled instead.", DeprecationWarning) + self.schema_metadata_enabled = enabled + self.token_metadata_enabled = enabled + + @classmethod + def _send_chunks(cls, connection, host, chunks, set_keyspace=False): + for ks_chunk in chunks: + messages = [PrepareMessage(query=s.query_string, + keyspace=s.keyspace if set_keyspace else None) + for s in ks_chunk] + # TODO: make this timeout configurable somehow? + responses = connection.wait_for_responses(*messages, timeout=5.0, fail_on_error=False) + for success, response in responses: + if not success: + log.debug("Got unexpected response when preparing " + "statement on host %s: %r", host, response) def _prepare_all_queries(self, host): - if not self._prepared_statements: + if not self._prepared_statements or not self.reprepare_on_up: return log.debug("Preparing all known prepared statements against host %s", host) connection = None try: - connection = self.connection_factory(host.address) - try: - self.control_connection.wait_for_schema_agreement(connection) - except Exception: - log.debug("Error waiting for schema agreement before preparing statements against host %s", host, exc_info=True) - - statements = self._prepared_statements.values() - for keyspace, ks_statements in groupby(statements, lambda s: s.keyspace): - if keyspace is not None: - connection.set_keyspace_blocking(keyspace) - - # prepare 10 statements at a time - ks_statements = list(ks_statements) + connection = self.connection_factory(host.endpoint) + statements = list(self._prepared_statements.values()) + if ProtocolVersion.uses_keyspace_flag(self.protocol_version): + # V5 protocol and higher, no need to set the keyspace chunks = [] - for i in range(0, len(ks_statements), 10): - chunks.append(ks_statements[i:i + 10]) - - for ks_chunk in chunks: - messages = [PrepareMessage(query=s.query_string) for s in ks_chunk] - # TODO: make this timeout configurable somehow? 
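Since ``set_meta_refresh_enabled()`` is deprecated above in favor of the two finer-grained attributes, the replacement usage looks roughly like this (the constructor keywords mirror the ``schema_metadata_enabled``/``token_metadata_enabled`` values passed to the control connection earlier in this diff)::

    # Instead of the deprecated cluster.set_meta_refresh_enabled(False):
    cluster = Cluster(schema_metadata_enabled=False, token_metadata_enabled=False)

    # Either flag can also be flipped later on the live cluster object.
    cluster.schema_metadata_enabled = True
    cluster.token_metadata_enabled = True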
- responses = connection.wait_for_responses(*messages, timeout=5.0) - for response in responses: - if (not isinstance(response, ResultMessage) or - response.kind != RESULT_KIND_PREPARED): - log.debug("Got unexpected response when preparing " - "statement on host %s: %r", host, response) + for i in range(0, len(statements), 10): + chunks.append(statements[i:i + 10]) + self._send_chunks(connection, host, chunks, True) + else: + for keyspace, ks_statements in groupby(statements, lambda s: s.keyspace): + if keyspace is not None: + connection.set_keyspace_blocking(keyspace) + + # prepare 10 statements at a time + ks_statements = list(ks_statements) + chunks = [] + for i in range(0, len(ks_statements), 10): + chunks.append(ks_statements[i:i + 10]) + self._send_chunks(connection, host, chunks) log.debug("Done preparing all known prepared statements against host %s", host) except OperationTimedOut as timeout: @@ -1374,11 +2347,9 @@ def _prepare_all_queries(self, host): if connection: connection.close() - def prepare_on_all_sessions(self, query_id, prepared_statement, excluded_host): + def add_prepared(self, query_id, prepared_statement): with self._prepared_statement_lock: self._prepared_statements[query_id] = prepared_statement - for session in self.sessions: - session.prepare_on_all_hosts(prepared_statement.query_string, excluded_host) class Session(object): @@ -1403,55 +2374,101 @@ class Session(object): hosts = None keyspace = None is_shutdown = False + session_id = None + _monitor_reporter = None - row_factory = staticmethod(named_tuple_factory) - """ - The format to return row results in. By default, each - returned row will be a named tuple. You can alternatively - use any of the following: + _row_factory = staticmethod(named_tuple_factory) + @property + def row_factory(self): + """ + The format to return row results in. By default, each + returned row will be a named tuple. You can alternatively + use any of the following: - - :func:`cassandra.query.tuple_factory` - return a result row as a tuple - - :func:`cassandra.query.named_tuple_factory` - return a result row as a named tuple - - :func:`cassandra.query.dict_factory` - return a result row as a dict - - :func:`cassandra.query.ordered_dict_factory` - return a result row as an OrderedDict + - :func:`cassandra.query.tuple_factory` - return a result row as a tuple + - :func:`cassandra.query.named_tuple_factory` - return a result row as a named tuple + - :func:`cassandra.query.dict_factory` - return a result row as a dict + - :func:`cassandra.query.ordered_dict_factory` - return a result row as an OrderedDict - """ + """ + return self._row_factory - default_timeout = 10.0 - """ - A default timeout, measured in seconds, for queries executed through - :meth:`.execute()` or :meth:`.execute_async()`. This default may be - overridden with the `timeout` parameter for either of those methods. + @row_factory.setter + def row_factory(self, rf): + self._validate_set_legacy_config('row_factory', rf) - Setting this to :const:`None` will cause no timeouts to be set by default. + _default_timeout = 10.0 - Please see :meth:`.ResponseFuture.result` for details on the scope and - effect of this timeout. + @property + def default_timeout(self): + """ + A default timeout, measured in seconds, for queries executed through + :meth:`.execute()` or :meth:`.execute_async()`. This default may be + overridden with the `timeout` parameter for either of those methods. - .. 
versionadded:: 2.0.0 - """ + Setting this to :const:`None` will cause no timeouts to be set by default. - default_consistency_level = ConsistencyLevel.LOCAL_ONE - """ - The default :class:`~ConsistencyLevel` for operations executed through - this session. This default may be overridden by setting the - :attr:`~.Statement.consistency_level` on individual statements. + Please see :meth:`.ResponseFuture.result` for details on the scope and + effect of this timeout. - .. versionadded:: 1.2.0 + .. versionadded:: 2.0.0 + """ + return self._default_timeout - .. versionchanged:: 3.0.0 + @default_timeout.setter + def default_timeout(self, timeout): + self._validate_set_legacy_config('default_timeout', timeout) - default changed from ONE to LOCAL_ONE - """ + _default_consistency_level = ConsistencyLevel.LOCAL_ONE - default_serial_consistency_level = None - """ - The default :class:`~ConsistencyLevel` for serial phase of conditional updates executed through - this session. This default may be overridden by setting the - :attr:`~.Statement.serial_consistency_level` on individual statements. + @property + def default_consistency_level(self): + """ + *Deprecated:* use execution profiles instead + The default :class:`~ConsistencyLevel` for operations executed through + this session. This default may be overridden by setting the + :attr:`~.Statement.consistency_level` on individual statements. - Only valid for ``protocol_version >= 2``. - """ + .. versionadded:: 1.2.0 + + .. versionchanged:: 3.0.0 + + default changed from ONE to LOCAL_ONE + """ + return self._default_consistency_level + + @default_consistency_level.setter + def default_consistency_level(self, cl): + """ + *Deprecated:* use execution profiles instead + """ + warn("Setting the consistency level at the session level will be removed in 4.0. Consider using " + "execution profiles and setting the desired consistency level to the EXEC_PROFILE_DEFAULT profile." + , DeprecationWarning) + self._validate_set_legacy_config('default_consistency_level', cl) + + _default_serial_consistency_level = None + + @property + def default_serial_consistency_level(self): + """ + The default :class:`~ConsistencyLevel` for serial phase of conditional updates executed through + this session. This default may be overridden by setting the + :attr:`~.Statement.serial_consistency_level` on individual statements. + + Only valid for ``protocol_version >= 2``. + """ + return self._default_serial_consistency_level + + @default_serial_consistency_level.setter + def default_serial_consistency_level(self, cl): + if (cl is not None and + not ConsistencyLevel.is_serial(cl)): + raise ValueError("default_serial_consistency_level must be either " + "ConsistencyLevel.SERIAL " + "or ConsistencyLevel.LOCAL_SERIAL.") + + self._validate_set_legacy_config('default_serial_consistency_level', cl) max_trace_wait = 2.0 """ @@ -1486,6 +2503,26 @@ class Session(object): .. versionadded:: 2.1.0 """ + timestamp_generator = None + """ + When :attr:`use_client_timestamp` is set, sessions call this object and use + the result as the timestamp. (Note that timestamps specified within a CQL + query will override this timestamp.) By default, a new + :class:`~.MonotonicTimestampGenerator` is created for + each :class:`Cluster` instance. + + Applications can set this value for custom timestamp behavior. 
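As a concrete sketch of the paragraph above: a single generator instance can be shared by several clusters, as the surrounding text describes. The ``cassandra.timestamps`` import path and the contact points are assumptions here::

    from cassandra.cluster import Cluster
    from cassandra.timestamps import MonotonicTimestampGenerator

    # One generator shared by two clusters keeps client-side timestamps unique
    # and increasing across both.
    shared_ts = MonotonicTimestampGenerator()
    analytics_cluster = Cluster(['10.0.0.1'], timestamp_generator=shared_ts)
    oltp_cluster = Cluster(['10.0.0.2'], timestamp_generator=shared_ts)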
For + example, an application could share a timestamp generator across + :class:`Cluster` objects to guarantee that the application will use unique, + increasing timestamps across clusters, or set it to to ``lambda: + int(time.time() * 1e6)`` if losing records over clock inconsistencies is + acceptable for the application. Custom :attr:`timestamp_generator` s should + be callable, and calling them should return an integer representing microseconds + since some point in time, typically UNIX epoch. + + .. versionadded:: 3.8.0 + """ + encoder = None """ A :class:`~cassandra.encoder.Encoder` instance that will be used when @@ -1523,34 +2560,82 @@ class Session(object): When compiled with Cython, there are also built-in faster alternatives. See :ref:`faster_deser` """ + session_id = None + """ + A UUID that uniquely identifies this Session to Insights. This will be + generated automatically. + """ + _lock = None _pools = None - _load_balancer = None + _profile_manager = None _metrics = None + _request_init_callbacks = None + _graph_paging_available = False - def __init__(self, cluster, hosts): + def __init__(self, cluster, hosts, keyspace=None): self.cluster = cluster self.hosts = hosts + self.keyspace = keyspace self._lock = RLock() self._pools = {} - self._load_balancer = cluster.load_balancing_policy + self._profile_manager = cluster.profile_manager self._metrics = cluster.metrics + self._request_init_callbacks = [] self._protocol_version = self.cluster.protocol_version self.encoder = Encoder() # create connection pools in parallel - futures = [] + self._initial_connect_futures = set() for host in hosts: future = self.add_or_renew_pool(host, is_host_addition=False) - if future is not None: - futures.append(future) + if future: + self._initial_connect_futures.add(future) - for future in futures: - future.result() + futures = wait_futures(self._initial_connect_futures, return_when=FIRST_COMPLETED) + while futures.not_done and not any(f.result() for f in futures.done): + futures = wait_futures(futures.not_done, return_when=FIRST_COMPLETED) - def execute(self, query, parameters=None, timeout=_NOT_SET, trace=False, custom_payload=None): + if not any(f.result() for f in self._initial_connect_futures): + msg = "Unable to connect to any servers" + if self.keyspace: + msg += " using keyspace '%s'" % self.keyspace + raise NoHostAvailable(msg, [h.address for h in hosts]) + + self.session_id = uuid.uuid4() + self._graph_paging_available = self._check_graph_paging_available() + + if self.cluster.column_encryption_policy is not None: + try: + self.client_protocol_handler = type( + str(self.session_id) + "-ProtocolHandler", + (ProtocolHandler,), + {"column_encryption_policy": self.cluster.column_encryption_policy}) + except AttributeError: + log.info("Unable to set column encryption policy for session") + + if self.cluster.monitor_reporting_enabled: + cc_host = self.cluster.get_control_connection_host() + valid_insights_version = (cc_host and version_supports_insights(cc_host.dse_version)) + if valid_insights_version: + self._monitor_reporter = MonitorReporter( + interval_sec=self.cluster.monitor_reporting_interval, + session=self, + ) + else: + if cc_host: + log.debug('Not starting MonitorReporter thread for Insights; ' + 'not supported by server version {v} on ' + 'ControlConnection host {c}'.format(v=cc_host.release_version, c=cc_host)) + + log.debug('Started Session with client_id {} and session_id {}'.format(self.cluster.client_id, + self.session_id)) + + def execute(self, query, parameters=None, 
timeout=_NOT_SET, trace=False, + custom_payload=None, execution_profile=EXEC_PROFILE_DEFAULT, + paging_state=None, host=None, execute_as=None): """ Execute the given query and synchronously wait for the response. @@ -1566,9 +2651,8 @@ def execute(self, query, parameters=None, timeout=_NOT_SET, trace=False, custom_ `timeout` should specify a floating-point timeout (in seconds) after which an :exc:`.OperationTimedOut` exception will be raised if the query - has not completed. If not set, the timeout defaults to - :attr:`~.Session.default_timeout`. If set to :const:`None`, there is - no timeout. Please see :meth:`.ResponseFuture.result` for details on + has not completed. If not set, the timeout defaults to the request_timeout of the selected ``execution_profile``. + If set to :const:`None`, there is no timeout. Please see :meth:`.ResponseFuture.result` for details on the scope and effect of this timeout. If `trace` is set to :const:`True`, the query will be sent with tracing enabled. @@ -1577,74 +2661,390 @@ def execute(self, query, parameters=None, timeout=_NOT_SET, trace=False, custom_ `custom_payload` is a :ref:`custom_payload` dict to be passed to the server. If `query` is a Statement with its own custom_payload. The message payload will be a union of the two, with the values specified here taking precedence. + + `execution_profile` is the execution profile to use for this request. It can be a key to a profile configured + via :meth:`Cluster.add_execution_profile` or an instance (from :meth:`Session.execution_profile_clone_update`, + for example + + `paging_state` is an optional paging state, reused from a previous :class:`ResultSet`. + + `host` is the :class:`cassandra.pool.Host` that should handle the query. If the host specified is down or + not yet connected, the query will fail with :class:`NoHostAvailable`. Using this is + discouraged except in a few cases, e.g., querying node-local tables and applying schema changes. + + `execute_as` the user that will be used on the server to execute the request. This is only available + on a DSE cluster. """ - return self.execute_async(query, parameters, trace, custom_payload, timeout).result() - def execute_async(self, query, parameters=None, trace=False, custom_payload=None, timeout=_NOT_SET): + return self.execute_async(query, parameters, trace, custom_payload, timeout, execution_profile, paging_state, host, execute_as).result() + + def execute_async(self, query, parameters=None, trace=False, custom_payload=None, + timeout=_NOT_SET, execution_profile=EXEC_PROFILE_DEFAULT, + paging_state=None, host=None, execute_as=None): """ Execute the given query and return a :class:`~.ResponseFuture` object which callbacks may be attached to for asynchronous response delivery. You may also call :meth:`~.ResponseFuture.result()` - on the :class:`.ResponseFuture` to syncronously block for results at + on the :class:`.ResponseFuture` to synchronously block for results at any time. - If `trace` is set to :const:`True`, you may get the query trace descriptors using - :meth:`.ResponseFuture.get_query_trace()` or :meth:`.ResponseFuture.get_all_query_traces()` - on the future result. + See :meth:`Session.execute` for parameter definitions. + + Example usage:: + + >>> session = cluster.connect() + >>> future = session.execute_async("SELECT * FROM mycf") + + >>> def log_results(results): + ... for row in results: + ... 
log.info("Results: %s", row) + + >>> def log_error(exc): + >>> log.error("Operation failed: %s", exc) + + >>> future.add_callbacks(log_results, log_error) + + Async execution with blocking wait for results:: + + >>> future = session.execute_async("SELECT * FROM mycf") + >>> # do other stuff... + + >>> try: + ... results = future.result() + ... except Exception: + ... log.exception("Operation failed:") + + """ + custom_payload = custom_payload if custom_payload else {} + if execute_as: + custom_payload[_proxy_execute_key] = execute_as.encode() + + future = self._create_response_future( + query, parameters, trace, custom_payload, timeout, + execution_profile, paging_state, host) + future._protocol_handler = self.client_protocol_handler + self._on_request(future) + future.send_request() + return future + + def execute_concurrent(self, statements_and_parameters, concurrency=100, raise_on_first_error=True, results_generator=False, execution_profile=EXEC_PROFILE_DEFAULT): + """ + Executes a sequence of (statement, parameters) tuples concurrently. Each + ``parameters`` item must be a sequence or :const:`None`. + + The `concurrency` parameter controls how many statements will be executed + concurrently. When :attr:`.Cluster.protocol_version` is set to 1 or 2, + it is recommended that this be kept below 100 times the number of + core connections per host times the number of connected hosts (see + :meth:`.Cluster.set_core_connections_per_host`). If that amount is exceeded, + the event loop thread may attempt to block on new connection creation, + substantially impacting throughput. If :attr:`~.Cluster.protocol_version` + is 3 or higher, you can safely experiment with higher levels of concurrency. + + If `raise_on_first_error` is left as :const:`True`, execution will stop + after the first failed statement and the corresponding exception will be + raised. + + `results_generator` controls how the results are returned. + + * If :const:`False`, the results are returned only after all requests have completed. + * If :const:`True`, a generator expression is returned. Using a generator results in a constrained + memory footprint when the results set will be large -- results are yielded + as they return instead of materializing the entire list at once. The trade for lower memory + footprint is marginal CPU overhead (more thread coordination and sorting out-of-order results + on-the-fly). + + `execution_profile` argument is the execution profile to use for this + request, it is passed directly to :meth:`Session.execute_async`. + + A sequence of ``ExecutionResult(success, result_or_exc)`` namedtuples is returned + in the same order that the statements were passed in. If ``success`` is :const:`False`, + there was an error executing the statement, and ``result_or_exc`` + will be an :class:`Exception`. If ``success`` is :const:`True`, ``result_or_exc`` + will be the query result. 
+ + Example usage:: + + select_statement = session.prepare("SELECT * FROM users WHERE id=?") + + statements_and_params = [] + for user_id in user_ids: + params = (user_id, ) + statements_and_params.append((select_statement, params)) + + results = session.execute_concurrent(statements_and_params, raise_on_first_error=False) + + for (success, result) in results: + if not success: + handle_error(result) # result will be an Exception + else: + process_user(result[0]) # result will be a list of rows + + Note: in the case that `generators` are used, it is important to ensure the consumers do not + block or attempt further synchronous requests, because no further IO will be processed until + the consumer returns. This may also produce a deadlock in the IO event thread. + """ + from cassandra.concurrent import execute_concurrent + return execute_concurrent(self, statements_and_parameters, concurrency, raise_on_first_error, results_generator, execution_profile) + + def execute_concurrent_with_args(self, statement, parameters, *args, **kwargs): + """ + Like :meth:`~cassandra.concurrent.execute_concurrent()`, but takes a single + statement and a sequence of parameters. Each item in ``parameters`` + should be a sequence or :const:`None`. + + Example usage:: + + statement = session.prepare("INSERT INTO mytable (a, b) VALUES (1, ?)") + parameters = [(x,) for x in range(1000)] + session.execute_concurrent_with_args(statement, parameters, concurrency=50) + """ + from cassandra.concurrent import execute_concurrent_with_args + return execute_concurrent_with_args(self, statement, parameters, *args, **kwargs) + + def execute_concurrent_async(self, statements_and_parameters, concurrency=100, raise_on_first_error=False, execution_profile=EXEC_PROFILE_DEFAULT): + """ + Asynchronously executes a sequence of (statement, parameters) tuples concurrently. + + Args: + session: Cassandra session object. + statement_and_parameters: Iterable of (prepared CQL statement, bind parameters) tuples. + concurrency (int, optional): Number of concurrent operations. Default is 100. + raise_on_first_error (bool, optional): If True, execution stops on the first error. Default is True. + execution_profile (ExecutionProfile, optional): Execution profile to use. Default is EXEC_PROFILE_DEFAULT. + + Returns: + A `Future` object that will be completed when all operations are done. + """ + from cassandra.concurrent import execute_concurrent_async + return execute_concurrent_async(self, statements_and_parameters, concurrency, raise_on_first_error, execution_profile) + + def execute_graph(self, query, parameters=None, trace=False, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT, execute_as=None): + """ + Executes a Gremlin query string or GraphStatement synchronously, + and returns a ResultSet from this execution. + + `parameters` is dict of named parameters to bind. The values must be + JSON-serializable. + + `execution_profile`: Selects an execution profile for the request. + + `execute_as` the user that will be used on the server to execute the request. + """ + return self.execute_graph_async(query, parameters, trace, execution_profile, execute_as).result() + + def execute_graph_async(self, query, parameters=None, trace=False, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT, execute_as=None): + """ + Execute the graph query and return a :class:`ResponseFuture` + object which callbacks may be attached to for asynchronous response delivery. You may also call ``ResponseFuture.result()`` to synchronously block for + results at any time. 
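A small sketch of the new ``execute_concurrent_async()`` wrapper described above; the table and parameters are made up, and only the documented behavior (a future that completes when all statements have run) is assumed::

    insert = session.prepare("INSERT INTO ks.users (id, name) VALUES (?, ?)")
    work = [(insert, (i, 'user-%d' % i)) for i in range(100)]

    # Returns immediately; the returned future completes once every statement has run.
    future = session.execute_concurrent_async(work, concurrency=25)
    future.result()  # block only where completion actually matters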
+ """ + if self.cluster._config_mode is _ConfigMode.LEGACY: + raise ValueError(("Cannot execute graph queries using Cluster legacy parameters. " + "Consider using Execution profiles: " + "https://docs.datastax.com/en/developer/python-driver/latest/execution_profiles/#execution-profiles")) + + if not isinstance(query, GraphStatement): + query = SimpleGraphStatement(query) + + # Clone and look up instance here so we can resolve and apply the extended attributes + execution_profile = self.execution_profile_clone_update(execution_profile) + + if not hasattr(execution_profile, 'graph_options'): + raise ValueError( + "Execution profile for graph queries must derive from GraphExecutionProfile, and provide graph_options") + + self._resolve_execution_profile_options(execution_profile) + + # make sure the graphson context row factory is binded to this cluster + try: + if issubclass(execution_profile.row_factory, _GraphSONContextRowFactory): + execution_profile.row_factory = execution_profile.row_factory(self.cluster) + except TypeError: + # issubclass might fail if arg1 is an instance + pass + + # set graph paging if needed + self._maybe_set_graph_paging(execution_profile) - `custom_payload` is a :ref:`custom_payload` dict to be passed to the server. - If `query` is a Statement with its own custom_payload. The message payload - will be a union of the two, with the values specified here taking precedence. + graph_parameters = None + if parameters: + graph_parameters = self._transform_params(parameters, graph_options=execution_profile.graph_options) - If the server sends a custom payload in the response message, - the dict can be obtained following :meth:`.ResponseFuture.result` via - :attr:`.ResponseFuture.custom_payload` + custom_payload = execution_profile.graph_options.get_options_map() + if execute_as: + custom_payload[_proxy_execute_key] = execute_as.encode() + custom_payload[_request_timeout_key] = int64_pack(int(execution_profile.request_timeout * 1000)) - Example usage:: + future = self._create_response_future(query, parameters=None, trace=trace, custom_payload=custom_payload, + timeout=_NOT_SET, execution_profile=execution_profile) - >>> session = cluster.connect() - >>> future = session.execute_async("SELECT * FROM mycf") + future.message.query_params = graph_parameters + future._protocol_handler = self.client_protocol_handler - >>> def log_results(results): - ... for row in results: - ... log.info("Results: %s", row) + if execution_profile.graph_options.is_analytics_source and \ + isinstance(execution_profile.load_balancing_policy, DefaultLoadBalancingPolicy): + self._target_analytics_master(future) + else: + future.send_request() + return future - >>> def log_error(exc): - >>> log.error("Operation failed: %s", exc) + def _maybe_set_graph_paging(self, execution_profile): + graph_paging = execution_profile.continuous_paging_options + if execution_profile.continuous_paging_options is _NOT_SET: + graph_paging = ContinuousPagingOptions() if self._graph_paging_available else None - >>> future.add_callbacks(log_results, log_error) + execution_profile.continuous_paging_options = graph_paging - Async execution with blocking wait for results:: + def _check_graph_paging_available(self): + """Verify if we can enable graph paging. This executed only once when the session is created.""" - >>> future = session.execute_async("SELECT * FROM mycf") - >>> # do other stuff... + if not ProtocolVersion.has_continuous_paging_next_pages(self._protocol_version): + return False - >>> try: - ... 
results = future.result() - ... except Exception: - ... log.exception("Operation failed:") + for host in self.cluster.metadata.all_hosts(): + if host.dse_version is None: + return False + + version = Version(host.dse_version) + if version < _GRAPH_PAGING_MIN_DSE_VERSION: + return False + + return True + def _resolve_execution_profile_options(self, execution_profile): """ - if timeout is _NOT_SET: - timeout = self.default_timeout + Determine the GraphSON protocol and row factory for a graph query. This is useful + to configure automatically the execution profile when executing a query on a + core graph. + + If `graph_protocol` is not explicitly specified, the following rules apply: + - Default to GraphProtocol.GRAPHSON_1_0, or GRAPHSON_2_0 if the `graph_language` is not gremlin-groovy. + - If `graph_options.graph_name` is specified and is a Core graph, set GraphSON_3_0. + If `row_factory` is not explicitly specified, the following rules apply: + - Default to graph_object_row_factory. + - If `graph_options.graph_name` is specified and is a Core graph, set graph_graphson3_row_factory. + """ + if execution_profile.graph_options.graph_protocol is not None and \ + execution_profile.row_factory is not None: + return - future = self._create_response_future(query, parameters, trace, custom_payload, timeout) - future._protocol_handler = self.client_protocol_handler - future.send_request() - return future + graph_options = execution_profile.graph_options + + is_core_graph = False + if graph_options.graph_name: + # graph_options.graph_name is bytes ... + name = graph_options.graph_name.decode('utf-8') + if name in self.cluster.metadata.keyspaces: + ks_metadata = self.cluster.metadata.keyspaces[name] + if ks_metadata.graph_engine == 'Core': + is_core_graph = True + + if is_core_graph: + graph_protocol = GraphProtocol.GRAPHSON_3_0 + row_factory = graph_graphson3_row_factory + else: + if graph_options.graph_language == GraphOptions.DEFAULT_GRAPH_LANGUAGE: + graph_protocol = GraphOptions.DEFAULT_GRAPH_PROTOCOL + row_factory = graph_object_row_factory + else: + # if not gremlin-groovy, GraphSON_2_0 + graph_protocol = GraphProtocol.GRAPHSON_2_0 + row_factory = graph_graphson2_row_factory + + # Only apply if not set explicitly + if graph_options.graph_protocol is None: + graph_options.graph_protocol = graph_protocol + if execution_profile.row_factory is None: + execution_profile.row_factory = row_factory + + def _transform_params(self, parameters, graph_options): + if not isinstance(parameters, dict): + raise ValueError('The parameters must be a dictionary. 
Unnamed parameters are not allowed.') + + # Serialize python types to graphson + serializer = GraphSON1Serializer + if graph_options.graph_protocol == GraphProtocol.GRAPHSON_2_0: + serializer = GraphSON2Serializer() + elif graph_options.graph_protocol == GraphProtocol.GRAPHSON_3_0: + # only required for core graphs + context = { + 'cluster': self.cluster, + 'graph_name': graph_options.graph_name.decode('utf-8') if graph_options.graph_name else None + } + serializer = GraphSON3Serializer(context) + + serialized_parameters = serializer.serialize(parameters) + return [json.dumps(serialized_parameters).encode('utf-8')] + + def _target_analytics_master(self, future): + future._start_timer() + master_query_future = self._create_response_future("CALL DseClientTool.getAnalyticsGraphServer()", + parameters=None, trace=False, + custom_payload=None, timeout=future.timeout) + master_query_future.row_factory = tuple_factory + master_query_future.send_request() + + cb = self._on_analytics_master_result + args = (master_query_future, future) + master_query_future.add_callbacks(callback=cb, callback_args=args, errback=cb, errback_args=args) + + def _on_analytics_master_result(self, response, master_future, query_future): + try: + row = master_future.result()[0] + addr = row[0]['location'] + delimiter_index = addr.rfind(':') # assumes : - not robust, but that's what is being provided + if delimiter_index > 0: + addr = addr[:delimiter_index] + targeted_query = HostTargetingStatement(query_future.query, addr) + query_future.query_plan = query_future._load_balancer.make_query_plan(self.keyspace, targeted_query) + except Exception: + log.debug("Failed querying analytics master (request might not be routed optimally). " + "Make sure the session is connecting to a graph analytics datacenter.", exc_info=True) + + self.submit(query_future.send_request) - def _create_response_future(self, query, parameters, trace, custom_payload, timeout): + def _create_response_future(self, query, parameters, trace, custom_payload, + timeout, execution_profile=EXEC_PROFILE_DEFAULT, + paging_state=None, host=None): """ Returns the ResponseFuture before calling send_request() on it """ prepared_statement = None - if isinstance(query, six.string_types): + if isinstance(query, str): query = SimpleStatement(query) elif isinstance(query, PreparedStatement): query = query.bind(parameters) - cl = query.consistency_level if query.consistency_level is not None else self.default_consistency_level - serial_cl = query.serial_consistency_level if query.serial_consistency_level is not None else self.default_serial_consistency_level + if self.cluster._config_mode == _ConfigMode.LEGACY: + if execution_profile is not EXEC_PROFILE_DEFAULT: + raise ValueError("Cannot specify execution_profile while using legacy parameters.") + + if timeout is _NOT_SET: + timeout = self.default_timeout + + cl = query.consistency_level if query.consistency_level is not None else self.default_consistency_level + serial_cl = query.serial_consistency_level if query.serial_consistency_level is not None else self.default_serial_consistency_level + + retry_policy = query.retry_policy or self.cluster.default_retry_policy + row_factory = self.row_factory + load_balancing_policy = self.cluster.load_balancing_policy + spec_exec_policy = None + continuous_paging_options = None + else: + execution_profile = self._maybe_get_execution_profile(execution_profile) + + if timeout is _NOT_SET: + timeout = execution_profile.request_timeout + + cl = query.consistency_level if 
query.consistency_level is not None else execution_profile.consistency_level + serial_cl = query.serial_consistency_level if query.serial_consistency_level is not None else execution_profile.serial_consistency_level + continuous_paging_options = execution_profile.continuous_paging_options + + retry_policy = query.retry_policy or execution_profile.retry_policy + row_factory = execution_profile.row_factory + load_balancing_policy = execution_profile.load_balancing_policy + spec_exec_policy = execution_profile.speculative_execution_policy fetch_size = query.fetch_size if fetch_size is FETCH_SIZE_UNSET and self._protocol_version >= 2: @@ -1652,44 +3052,130 @@ def _create_response_future(self, query, parameters, trace, custom_payload, time elif self._protocol_version == 1: fetch_size = None + start_time = time.time() if self._protocol_version >= 3 and self.use_client_timestamp: - timestamp = int(time.time() * 1e6) + timestamp = self.cluster.timestamp_generator() else: timestamp = None + supports_continuous_paging_state = ( + ProtocolVersion.has_continuous_paging_next_pages(self._protocol_version) + ) + if continuous_paging_options and supports_continuous_paging_state: + continuous_paging_state = ContinuousPagingState(continuous_paging_options.max_queue_size) + else: + continuous_paging_state = None + if isinstance(query, SimpleStatement): query_string = query.query_string + statement_keyspace = query.keyspace if ProtocolVersion.uses_keyspace_flag(self._protocol_version) else None if parameters: query_string = bind_params(query_string, parameters, self.encoder) message = QueryMessage( query_string, cl, serial_cl, - fetch_size, timestamp=timestamp) + fetch_size, paging_state, timestamp, + continuous_paging_options, statement_keyspace) elif isinstance(query, BoundStatement): - message = ExecuteMessage( - query.prepared_statement.query_id, query.values, cl, - serial_cl, fetch_size, - timestamp=timestamp) prepared_statement = query.prepared_statement + message = ExecuteMessage( + prepared_statement.query_id, query.values, cl, + serial_cl, fetch_size, paging_state, timestamp, + skip_meta=bool(prepared_statement.result_metadata), + continuous_paging_options=continuous_paging_options, + result_metadata_id=prepared_statement.result_metadata_id) elif isinstance(query, BatchStatement): if self._protocol_version < 2: raise UnsupportedOperation( "BatchStatement execution is only supported with protocol version " "2 or higher (supported in Cassandra 2.0 and higher). 
Consider " "setting Cluster.protocol_version to 2 to support this operation.") + statement_keyspace = query.keyspace if ProtocolVersion.uses_keyspace_flag(self._protocol_version) else None message = BatchMessage( query.batch_type, query._statements_and_parameters, cl, - serial_cl, timestamp) + serial_cl, timestamp, statement_keyspace) + elif isinstance(query, GraphStatement): + # the statement_keyspace is not aplicable to GraphStatement + message = QueryMessage(query.query, cl, serial_cl, fetch_size, + paging_state, timestamp, + continuous_paging_options) message.tracing = trace - message.update_custom_payload(query.custom_payload) message.update_custom_payload(custom_payload) + message.allow_beta_protocol_version = self.cluster.allow_beta_protocol_version + spec_exec_plan = spec_exec_policy.new_plan(query.keyspace or self.keyspace, query) if query.is_idempotent and spec_exec_policy else None return ResponseFuture( self, message, query, timeout, metrics=self._metrics, - prepared_statement=prepared_statement) + prepared_statement=prepared_statement, retry_policy=retry_policy, row_factory=row_factory, + load_balancer=load_balancing_policy, start_time=start_time, speculative_execution_plan=spec_exec_plan, + continuous_paging_state=continuous_paging_state, host=host) + + def get_execution_profile(self, name): + """ + Returns the execution profile associated with the provided ``name``. + + :param name: The name (or key) of the execution profile. + """ + profiles = self.cluster.profile_manager.profiles + try: + return profiles[name] + except KeyError: + eps = [_execution_profile_to_string(ep) for ep in profiles.keys()] + raise ValueError("Invalid execution_profile: %s; valid profiles are: %s." % ( + _execution_profile_to_string(name), ', '.join(eps))) + + def _maybe_get_execution_profile(self, ep): + return ep if isinstance(ep, ExecutionProfile) else self.get_execution_profile(ep) + + def execution_profile_clone_update(self, ep, **kwargs): + """ + Returns a clone of the ``ep`` profile. ``kwargs`` can be specified to update attributes + of the returned profile. + + This is a shallow clone, so any objects referenced by the profile are shared. This means Load Balancing Policy + is maintained by inclusion in the active profiles. It also means updating any other rich objects will be seen + by the active profile. In cases where this is not desirable, be sure to replace the instance instead of manipulating + the shared object. + """ + clone = copy(self._maybe_get_execution_profile(ep)) + for attr, value in kwargs.items(): + setattr(clone, attr, value) + return clone + + def add_request_init_listener(self, fn, *args, **kwargs): + """ + Adds a callback with arguments to be called when any request is created. + + It will be invoked as `fn(response_future, *args, **kwargs)` after each client request is created, + and before the request is sent. This can be used to create extensions by adding result callbacks to the + response future. + + `response_future` is the :class:`.ResponseFuture` for the request. + + Note that the init callback is done on the client thread creating the request, so you may need to consider + synchronization if you have multiple threads. Any callbacks added to the response future will be executed + on the event loop thread, so the normal advice about minimizing cycles and avoiding blocking apply (see Note in + :meth:`.ResponseFuture.add_callbacks`. - def prepare(self, query, custom_payload=None): + See `this example `_ in the + source tree for an example. 
+ """ + self._request_init_callbacks.append((fn, args, kwargs)) + + def remove_request_init_listener(self, fn, *args, **kwargs): + """ + Removes a callback and arguments from the list. + + See :meth:`.Session.add_request_init_listener`. + """ + self._request_init_callbacks.remove((fn, args, kwargs)) + + def _on_request(self, response_future): + for fn, args, kwargs in self._request_init_callbacks: + fn(response_future, *args, **kwargs) + + def prepare(self, query, custom_payload=None, keyspace=None): """ Prepares a query string, returning a :class:`~cassandra.query.PreparedStatement` instance which can be used as follows:: @@ -1712,43 +3198,59 @@ def prepare(self, query, custom_payload=None): ... bound = prepared.bind((user.id, user.name, user.age)) ... session.execute(bound) + Alternatively, if :attr:`~.Cluster.protocol_version` is 5 or higher + (requires Cassandra 4.0+), the keyspace can be specified as a + parameter. This will allow you to avoid specifying the keyspace in the + query without specifying a keyspace in :meth:`~.Cluster.connect`. It + even will let you prepare and use statements against a keyspace other + than the one originally specified on connection: + + >>> analyticskeyspace_prepared = session.prepare( + ... "INSERT INTO user_activity id, last_activity VALUES (?, ?)", + ... keyspace="analyticskeyspace") # note the different keyspace + **Important**: PreparedStatements should be prepared only once. Preparing the same query more than once will likely affect performance. `custom_payload` is a key value map to be passed along with the prepare message. See :ref:`custom_payload`. """ - message = PrepareMessage(query=query) + message = PrepareMessage(query=query, keyspace=keyspace) future = ResponseFuture(self, message, query=None, timeout=self.default_timeout) try: future.send_request() - query_id, column_metadata, pk_indexes = future.result() + response = future.result().one() except Exception: log.exception("Error preparing query:") raise + prepared_keyspace = keyspace if keyspace else None prepared_statement = PreparedStatement.from_message( - query_id, column_metadata, pk_indexes, self.cluster.metadata, query, self.keyspace, - self._protocol_version) + response.query_id, response.bind_metadata, response.pk_indexes, self.cluster.metadata, query, prepared_keyspace, + self._protocol_version, response.column_metadata, response.result_metadata_id, self.cluster.column_encryption_policy) prepared_statement.custom_payload = future.custom_payload - host = future._current_host - try: - self.cluster.prepare_on_all_sessions(query_id, prepared_statement, host) - except Exception: - log.exception("Error preparing query on all hosts:") + self.cluster.add_prepared(response.query_id, prepared_statement) + + if self.cluster.prepare_on_all_hosts: + host = future._current_host + try: + self.prepare_on_all_hosts(prepared_statement.query_string, host, prepared_keyspace) + except Exception: + log.exception("Error preparing query on all hosts:") return prepared_statement - def prepare_on_all_hosts(self, query, excluded_host): + def prepare_on_all_hosts(self, query, excluded_host, keyspace=None): """ Prepare the given query on all hosts, excluding ``excluded_host``. Intended for internal use only. 
""" futures = [] - for host in self._pools.keys(): + for host in tuple(self._pools.keys()): if host != excluded_host and host.is_up: - future = ResponseFuture(self, PrepareMessage(query=query), None, self.default_timeout) + future = ResponseFuture(self, PrepareMessage(query=query, keyspace=keyspace), + None, self.default_timeout) # we don't care about errors preparing against specific hosts, # since we can always prepare them as needed when the prepared @@ -1760,7 +3262,7 @@ def prepare_on_all_hosts(self, query, excluded_host): continue if request_id is None: - # the error has already been logged by ResponsFuture + # the error has already been logged by ResponseFuture log.debug("Failed to prepare query for host %s: %r", host, future._errors.get(host)) continue @@ -1784,14 +3286,39 @@ def shutdown(self): else: self.is_shutdown = True - for pool in self._pools.values(): + # PYTHON-673. If shutdown was called shortly after session init, avoid + # a race by cancelling any initial connection attempts haven't started, + # then blocking on any that have. + for future in self._initial_connect_futures: + future.cancel() + wait_futures(self._initial_connect_futures) + + if self._monitor_reporter: + self._monitor_reporter.stop() + + for pool in tuple(self._pools.values()): pool.shutdown() + def __enter__(self): + return self + + def __exit__(self, *args): + self.shutdown() + + def __del__(self): + try: + # Ensure all connections are closed, in case the Session object is deleted by the GC + self.shutdown() + except: + # Ignore all errors. Shutdown errors can be caught by the user + # when cluster.shutdown() is called explicitly. + pass + def add_or_renew_pool(self, host, is_host_addition): """ For internal use only. """ - distance = self._load_balancer.distance(host) + distance = self._profile_manager.distance(host) if distance == HostDistance.IGNORED: return None @@ -1800,9 +3327,10 @@ def run_add_or_renew_pool(): if self._protocol_version >= 3: new_pool = HostConnection(host, distance, self) else: + # TODO remove host pool again ??? new_pool = HostConnectionPool(host, distance, self) except AuthenticationFailed as auth_exc: - conn_exc = ConnectionException(str(auth_exc), host=host) + conn_exc = ConnectionException(str(auth_exc), endpoint=host) self.cluster.signal_connection_failure(host, conn_exc, is_host_addition) return False except Exception as conn_exc: @@ -1815,7 +3343,27 @@ def run_add_or_renew_pool(): return False previous = self._pools.get(host) - self._pools[host] = new_pool + with self._lock: + while new_pool._keyspace != self.keyspace: + self._lock.release() + set_keyspace_event = Event() + errors_returned = [] + + def callback(pool, errors): + errors_returned.extend(errors) + set_keyspace_event.set() + + new_pool._set_keyspace_for_all_conns(self.keyspace, callback) + set_keyspace_event.wait(self.cluster.connect_timeout) + if not set_keyspace_event.is_set() or errors_returned: + log.warning("Failed setting keyspace for pool after keyspace changed during connect: %s", errors_returned) + self.cluster.on_down(host, is_host_addition) + new_pool.shutdown() + self._lock.acquire() + return False + self._lock.acquire() + self._pools[host] = new_pool + log.debug("Added pool for host %s to session", host) if previous: previous.shutdown() @@ -1844,19 +3392,26 @@ def update_created_pools(self): For internal use only. 
""" + futures = set() for host in self.cluster.metadata.all_hosts(): - distance = self._load_balancer.distance(host) + distance = self._profile_manager.distance(host) pool = self._pools.get(host) - + future = None if not pool or pool.is_shutdown: - if distance != HostDistance.IGNORED and host.is_up: - self.add_or_renew_pool(host, False) + # we don't eagerly set is_up on previously ignored hosts. None is included here + # to allow us to attempt connections to hosts that have gone from ignored to something + # else. + if distance != HostDistance.IGNORED and host.is_up in (True, None): + future = self.add_or_renew_pool(host, False) elif distance != pool.host_distance: # the distance has changed if distance == HostDistance.IGNORED: - self.remove_pool(host) + future = self.remove_pool(host) else: pool.host_distance = distance + if future: + futures.add(future) + return futures def on_down(self, host): """ @@ -1885,9 +3440,9 @@ def _set_keyspace_for_all_pools(self, keyspace, callback): called with a dictionary of all errors that occurred, keyed by the `Host` that they occurred against. """ - self.keyspace = keyspace - - remaining_callbacks = set(self._pools.values()) + with self._lock: + self.keyspace = keyspace + remaining_callbacks = set(self._pools.values()) errors = {} if not remaining_callbacks: @@ -1902,7 +3457,7 @@ def pool_finished_setting_keyspace(pool, host_errors): if not remaining_callbacks: callback(host_errors) - for pool in self._pools.values(): + for pool in tuple(self._pools.values()): pool._set_keyspace_for_all_conns(keyspace, pool_finished_setting_keyspace) def user_type_registered(self, keyspace, user_type, klass): @@ -1924,10 +3479,6 @@ def user_type_registered(self, keyspace, user_type, klass): 'User type %s does not exist in keyspace %s' % (user_type, keyspace)) field_names = type_meta.field_names - if six.PY2: - # go from unicode to string to avoid decode errors from implicit - # decode when formatting non-ascii values - field_names = [fn.encode('utf-8') for fn in field_names] def encode(val): return '{ %s }' % ' , '.join('%s : %s' % ( @@ -1943,11 +3494,17 @@ def submit(self, fn, *args, **kwargs): return self.cluster.executor.submit(fn, *args, **kwargs) def get_pool_state(self): - return dict((host, pool.get_state()) for host, pool in self._pools.items()) + return dict((host, pool.get_state()) for host, pool in tuple(self._pools.items())) def get_pools(self): return self._pools.values() + def _validate_set_legacy_config(self, attr_name, value): + if self.cluster._config_mode == _ConfigMode.PROFILES: + raise ValueError("Cannot set Session.%s while using Configuration Profiles. Set this in a profile instead." 
% (attr_name,)) + setattr(self, '_' + attr_name, value) + self.cluster._config_mode = _ConfigMode.LEGACY + class UserTypeDoesNotExist(Exception): """ @@ -2009,27 +3566,49 @@ class ControlConnection(object): Internal """ - _SELECT_PEERS = "SELECT peer, data_center, rack, tokens, rpc_address, schema_version FROM system.peers" - _SELECT_LOCAL = "SELECT cluster_name, data_center, rack, tokens, partitioner, release_version, schema_version FROM system.local WHERE key='local'" + _SELECT_PEERS = "SELECT * FROM system.peers" + _SELECT_PEERS_NO_TOKENS_TEMPLATE = "SELECT host_id, peer, data_center, rack, rpc_address, {nt_col_name}, release_version, schema_version FROM system.peers" + _SELECT_LOCAL = "SELECT * FROM system.local WHERE key='local'" + _SELECT_LOCAL_NO_TOKENS = "SELECT host_id, cluster_name, data_center, rack, partitioner, release_version, schema_version FROM system.local WHERE key='local'" + # Used only when token_metadata_enabled is set to False + _SELECT_LOCAL_NO_TOKENS_RPC_ADDRESS = "SELECT rpc_address FROM system.local WHERE key='local'" - _SELECT_SCHEMA_PEERS = "SELECT peer, rpc_address, schema_version FROM system.peers" + _SELECT_SCHEMA_PEERS_TEMPLATE = "SELECT peer, host_id, {nt_col_name}, schema_version FROM system.peers" _SELECT_SCHEMA_LOCAL = "SELECT schema_version FROM system.local WHERE key='local'" + _SELECT_PEERS_V2 = "SELECT * FROM system.peers_v2" + _SELECT_PEERS_NO_TOKENS_V2 = "SELECT host_id, peer, peer_port, data_center, rack, native_address, native_port, release_version, schema_version FROM system.peers_v2" + _SELECT_SCHEMA_PEERS_V2 = "SELECT host_id, peer, peer_port, native_address, native_port, schema_version FROM system.peers_v2" + + _MINIMUM_NATIVE_ADDRESS_DSE_VERSION = Version("6.0.0") + + class PeersQueryType(object): + """internal Enum for _peers_query""" + PEERS = 0 + PEERS_SCHEMA = 1 + _is_shutdown = False _timeout = None _protocol_version = None _schema_event_refresh_window = None _topology_event_refresh_window = None + _status_event_refresh_window = None - _meta_refresh_enabled = True + _schema_meta_enabled = True + _token_meta_enabled = True + + _uses_peers_v2 = True # for testing purposes _time = time def __init__(self, cluster, timeout, schema_event_refresh_window, - topology_event_refresh_window): + topology_event_refresh_window, + status_event_refresh_window, + schema_meta_enabled=True, + token_meta_enabled=True): # use a weak reference to allow the Cluster instance to be GC'ed (and # shutdown) since implementing __del__ disables the cycle detector self._cluster = weakref.proxy(cluster) @@ -2038,6 +3617,9 @@ def __init__(self, cluster, timeout, self._schema_event_refresh_window = schema_event_refresh_window self._topology_event_refresh_window = topology_event_refresh_window + self._status_event_refresh_window = status_event_refresh_window + self._schema_meta_enabled = schema_meta_enabled + self._token_meta_enabled = token_meta_enabled self._lock = RLock() self._schema_agreement_lock = Lock() @@ -2054,6 +3636,8 @@ def connect(self): self._protocol_version = self._cluster.protocol_version self._set_new_connection(self._reconnect_internal()) + self._cluster.metadata.dbaas = self._connection._product_type == dscloud.DATASTAX_CLOUD_PRODUCT_TYPE + def _set_new_connection(self, conn): """ Replace existing connection (if there is one) and close it. @@ -2076,16 +3660,24 @@ def _reconnect_internal(self): a connection to that host. 
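The schema_meta_enabled / token_meta_enabled switches passed into the control connection above correspond to Cluster-level settings; a hedged sketch of a metadata-light client (contact point illustrative):

from cassandra.cluster import Cluster

# Skip token and schema queries on the control connection. Token-aware routing
# and cluster.metadata.keyspaces will not be populated as a consequence.
cluster = Cluster(["127.0.0.1"],
                  token_metadata_enabled=False,
                  schema_metadata_enabled=False)
session = cluster.connect()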
""" errors = {} - for host in self._cluster.load_balancing_policy.make_query_plan(): + lbp = ( + self._cluster.load_balancing_policy + if self._cluster._config_mode == _ConfigMode.LEGACY else + self._cluster._default_load_balancing_policy + ) + + for host in lbp.make_query_plan(): try: return self._try_connect(host) except ConnectionException as exc: - errors[host.address] = exc + errors[str(host.endpoint)] = exc log.warning("[control connection] Error connecting to %s:", host, exc_info=True) self._cluster.signal_connection_failure(host, exc, is_host_addition=False) except Exception as exc: - errors[host.address] = exc + errors[str(host.endpoint)] = exc log.warning("[control connection] Error connecting to %s:", host, exc_info=True) + if self._is_shutdown: + raise DriverException("[control connection] Reconnection in progress during shutdown") raise NoHostAvailable("Unable to connect to any servers", errors) @@ -2098,10 +3690,21 @@ def _try_connect(self, host): while True: try: - connection = self._cluster.connection_factory(host.address, is_control_connection=True) + connection = self._cluster.connection_factory(host.endpoint, is_control_connection=True) + if self._is_shutdown: + connection.close() + raise DriverException("Reconnecting during shutdown") break except ProtocolVersionUnsupported as e: - self._cluster.protocol_downgrade(host.address, e.startup_version) + self._cluster.protocol_downgrade(host.endpoint, e.startup_version) + except ProtocolException as e: + # protocol v5 is out of beta in C* >=4.0-beta5 and is now the default driver + # protocol version. If the protocol version was not explicitly specified, + # and that the server raises a beta protocol error, we should downgrade. + if not self._cluster._protocol_version_explicit and e.is_beta_protocol_error: + self._cluster.protocol_downgrade(host.endpoint, self._cluster.protocol_version) + else: + raise log.debug("[control connection] Established new connection %r, " "registering watchers and refreshing schema and topology", @@ -2111,7 +3714,7 @@ def _try_connect(self, host): # _clear_watcher will be called when this ControlConnection is about to be finalized # _watch_callback will get the actual callback from the Connection and relay it to # this object (after a dereferencing a weakref) - self_weakref = weakref.ref(self, callback=partial(_clear_watcher, weakref.proxy(connection))) + self_weakref = weakref.ref(self, partial(_clear_watcher, weakref.proxy(connection))) try: connection.register_watchers({ "TOPOLOGY_CHANGE": partial(_watch_callback, self_weakref, '_handle_topology_change'), @@ -2119,11 +3722,25 @@ def _try_connect(self, host): "SCHEMA_CHANGE": partial(_watch_callback, self_weakref, '_handle_schema_change') }, register_timeout=self._timeout) - peers_query = QueryMessage(query=self._SELECT_PEERS, consistency_level=ConsistencyLevel.ONE) - local_query = QueryMessage(query=self._SELECT_LOCAL, consistency_level=ConsistencyLevel.ONE) - shared_results = connection.wait_for_responses( - peers_query, local_query, timeout=self._timeout) - + sel_peers = self._get_peers_query(self.PeersQueryType.PEERS, connection) + sel_local = self._SELECT_LOCAL if self._token_meta_enabled else self._SELECT_LOCAL_NO_TOKENS + peers_query = QueryMessage(query=sel_peers, consistency_level=ConsistencyLevel.ONE) + local_query = QueryMessage(query=sel_local, consistency_level=ConsistencyLevel.ONE) + (peers_success, peers_result), (local_success, local_result) = connection.wait_for_responses( + peers_query, local_query, timeout=self._timeout, 
fail_on_error=False) + + if not local_success: + raise local_result + + if not peers_success: + # error with the peers v2 query, fallback to peers v1 + self._uses_peers_v2 = False + sel_peers = self._get_peers_query(self.PeersQueryType.PEERS, connection) + peers_query = QueryMessage(query=sel_peers, consistency_level=ConsistencyLevel.ONE) + peers_result = connection.wait_for_response( + peers_query, timeout=self._timeout) + + shared_results = (peers_result, local_result) self._refresh_node_list_and_token_map(connection, preloaded_results=shared_results) self._refresh_schema(connection, preloaded_results=shared_results, schema_agreement_wait=-1) except Exception: @@ -2144,7 +3761,7 @@ def _reconnect(self): self._set_new_connection(self._reconnect_internal()) except NoHostAvailable: # make a retry schedule (which includes backoff) - schedule = self.cluster.reconnection_policy.new_schedule() + schedule = self._cluster.reconnection_policy.new_schedule() with self._reconnection_lock: @@ -2198,16 +3815,12 @@ def shutdown(self): log.debug("Shutting down control connection") if self._connection: self._connection.close() - del self._connection - - def refresh_schema(self, **kwargs): - if not self._meta_refresh_enabled: - log.debug("[control connection] Skipping schema refresh because meta refresh is disabled") - return False + self._connection = None + def refresh_schema(self, force=False, **kwargs): try: if self._connection: - return self._refresh_schema(self._connection, **kwargs) + return self._refresh_schema(self._connection, force=force, **kwargs) except ReferenceError: pass # our weak reference to the Cluster is no good except Exception: @@ -2215,13 +3828,18 @@ def refresh_schema(self, **kwargs): self._signal_error() return False - def _refresh_schema(self, connection, preloaded_results=None, schema_agreement_wait=None, **kwargs): + def _refresh_schema(self, connection, preloaded_results=None, schema_agreement_wait=None, force=False, **kwargs): if self._cluster.is_shutdown: return False agreed = self.wait_for_schema_agreement(connection, preloaded_results=preloaded_results, wait_time=schema_agreement_wait) + + if not self._schema_meta_enabled and not force: + log.debug("[control connection] Skipping schema refresh because schema metadata is disabled") + return False + if not agreed: log.debug("Skipping schema refresh due to lack of schema agreement") return False @@ -2231,10 +3849,6 @@ def _refresh_schema(self, connection, preloaded_results=None, schema_agreement_w return True def refresh_node_list_and_token_map(self, force_token_rebuild=False): - if not self._meta_refresh_enabled: - log.debug("[control connection] Skipping node list refresh because meta refresh is disabled") - return False - try: if self._connection: self._refresh_node_list_and_token_map(self._connection, force_token_rebuild=force_token_rebuild) @@ -2248,86 +3862,144 @@ def refresh_node_list_and_token_map(self, force_token_rebuild=False): def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, force_token_rebuild=False): - if preloaded_results: log.debug("[control connection] Refreshing node list and token map using preloaded results") peers_result = preloaded_results[0] local_result = preloaded_results[1] else: - log.debug("[control connection] Refreshing node list and token map") cl = ConsistencyLevel.ONE - peers_query = QueryMessage(query=self._SELECT_PEERS, consistency_level=cl) - local_query = QueryMessage(query=self._SELECT_LOCAL, consistency_level=cl) + sel_peers = 
self._get_peers_query(self.PeersQueryType.PEERS, connection) + if not self._token_meta_enabled: + log.debug("[control connection] Refreshing node list without token map") + sel_local = self._SELECT_LOCAL_NO_TOKENS + else: + log.debug("[control connection] Refreshing node list and token map") + sel_local = self._SELECT_LOCAL + peers_query = QueryMessage(query=sel_peers, consistency_level=cl) + local_query = QueryMessage(query=sel_local, consistency_level=cl) peers_result, local_result = connection.wait_for_responses( peers_query, local_query, timeout=self._timeout) - peers_result = dict_factory(*peers_result.results) + peers_result = dict_factory(peers_result.column_names, peers_result.parsed_rows) partitioner = None token_map = {} - if local_result.results: - local_rows = dict_factory(*(local_result.results)) + found_hosts = set() + if local_result.parsed_rows: + found_hosts.add(connection.endpoint) + local_rows = dict_factory(local_result.column_names, local_result.parsed_rows) local_row = local_rows[0] cluster_name = local_row["cluster_name"] self._cluster.metadata.cluster_name = cluster_name - host = self._cluster.metadata.get_host(connection.host) + partitioner = local_row.get("partitioner") + tokens = local_row.get("tokens") + + host = self._cluster.metadata.get_host(connection.endpoint) if host: datacenter = local_row.get("data_center") rack = local_row.get("rack") self._update_location_info(host, datacenter, rack) + host.host_id = local_row.get("host_id") + host.listen_address = local_row.get("listen_address") + host.listen_port = local_row.get("listen_port") + host.broadcast_address = _NodeInfo.get_broadcast_address(local_row) + host.broadcast_port = _NodeInfo.get_broadcast_port(local_row) + + host.broadcast_rpc_address = _NodeInfo.get_broadcast_rpc_address(local_row) + host.broadcast_rpc_port = _NodeInfo.get_broadcast_rpc_port(local_row) + if host.broadcast_rpc_address is None: + if self._token_meta_enabled: + # local rpc_address is not available, use the connection endpoint + host.broadcast_rpc_address = connection.endpoint.address + host.broadcast_rpc_port = connection.endpoint.port + else: + # local rpc_address has not been queried yet, try to fetch it + # separately, which might fail because C* < 2.1.6 doesn't have rpc_address + # in system.local. See CASSANDRA-9436. + local_rpc_address_query = QueryMessage(query=self._SELECT_LOCAL_NO_TOKENS_RPC_ADDRESS, + consistency_level=ConsistencyLevel.ONE) + success, local_rpc_address_result = connection.wait_for_response( + local_rpc_address_query, timeout=self._timeout, fail_on_error=False) + if success: + row = dict_factory( + local_rpc_address_result.column_names, + local_rpc_address_result.parsed_rows) + host.broadcast_rpc_address = _NodeInfo.get_broadcast_rpc_address(row[0]) + host.broadcast_rpc_port = _NodeInfo.get_broadcast_rpc_port(row[0]) + else: + host.broadcast_rpc_address = connection.endpoint.address + host.broadcast_rpc_port = connection.endpoint.port - partitioner = local_row.get("partitioner") - tokens = local_row.get("tokens") - if partitioner and tokens: - token_map[host] = tokens + host.release_version = local_row.get("release_version") + host.dse_version = local_row.get("dse_version") + host.dse_workload = local_row.get("workload") + host.dse_workloads = local_row.get("workloads") - connection.server_version = local_row['release_version'] + if partitioner and tokens: + token_map[host] = tokens # Check metadata.partitioner to see if we haven't built anything yet. 
If # every node in the cluster was in the contact points, we won't discover # any new nodes, so we need this additional check. (See PYTHON-90) should_rebuild_token_map = force_token_rebuild or self._cluster.metadata.partitioner is None - found_hosts = set() for row in peers_result: - addr = row.get("rpc_address") + if not self._is_valid_peer(row): + log.warning( + "Found an invalid row for peer (%s). Ignoring host." % + _NodeInfo.get_broadcast_rpc_address(row)) + continue - if not addr or addr in ["0.0.0.0", "::"]: - addr = row.get("peer") + endpoint = self._cluster.endpoint_factory.create(row) - tokens = row.get("tokens") - if not tokens: - log.warning("Excluding host (%s) with no tokens in system.peers table of %s." % (addr, connection.host)) + if endpoint in found_hosts: + log.warning("Found multiple hosts with the same endpoint (%s). Excluding peer %s", endpoint, row.get("peer")) continue - found_hosts.add(addr) + found_hosts.add(endpoint) - host = self._cluster.metadata.get_host(addr) + host = self._cluster.metadata.get_host(endpoint) datacenter = row.get("data_center") rack = row.get("rack") if host is None: - log.debug("[control connection] Found new host to connect to: %s", addr) - host, _ = self._cluster.add_host(addr, datacenter, rack, signal=True, refresh_nodes=False) + log.debug("[control connection] Found new host to connect to: %s", endpoint) + host, _ = self._cluster.add_host(endpoint, datacenter, rack, signal=True, refresh_nodes=False) should_rebuild_token_map = True else: should_rebuild_token_map |= self._update_location_info(host, datacenter, rack) - if partitioner and tokens: + host.host_id = row.get("host_id") + host.broadcast_address = _NodeInfo.get_broadcast_address(row) + host.broadcast_port = _NodeInfo.get_broadcast_port(row) + host.broadcast_rpc_address = _NodeInfo.get_broadcast_rpc_address(row) + host.broadcast_rpc_port = _NodeInfo.get_broadcast_rpc_port(row) + host.release_version = row.get("release_version") + host.dse_version = row.get("dse_version") + host.dse_workload = row.get("workload") + host.dse_workloads = row.get("workloads") + + tokens = row.get("tokens", None) + if partitioner and tokens and self._token_meta_enabled: token_map[host] = tokens for old_host in self._cluster.metadata.all_hosts(): - if old_host.address != connection.host and old_host.address not in found_hosts: + if old_host.endpoint.address != connection.endpoint and old_host.endpoint not in found_hosts: should_rebuild_token_map = True - if old_host.address not in self._cluster.contact_points: - log.debug("[control connection] Found host that has been removed: %r", old_host) - self._cluster.remove_host(old_host) + log.debug("[control connection] Removing host not found in peers metadata: %r", old_host) + self._cluster.remove_host(old_host) log.debug("[control connection] Finished fetching ring info") if partitioner and should_rebuild_token_map: log.debug("[control connection] Rebuilding token map due to topology changes") self._cluster.metadata.rebuild_token_map(partitioner, token_map) + @staticmethod + def _is_valid_peer(row): + return bool(_NodeInfo.get_broadcast_rpc_address(row) and row.get("host_id") and + row.get("data_center") and row.get("rack") and + ('tokens' not in row or row.get('tokens'))) + def _update_location_info(self, host, datacenter, rack): if host.datacenter == datacenter and host.rack == rack: return False @@ -2335,9 +4007,9 @@ def _update_location_info(self, host, datacenter, rack): # If the dc/rack information changes, we need to update the load balancing policy. 
# For that, we remove and re-add the node against the policy. Not the most elegant, and assumes # that the policy will update correctly, but in practice this should work. - self._cluster.load_balancing_policy.on_down(host) + self._cluster.profile_manager.on_down(host) host.set_location_info(datacenter, rack) - self._cluster.load_balancing_policy.on_up(host) + self._cluster.profile_manager.on_up(host) return True def _delay_for_event_type(self, event_type, delay_window): @@ -2354,23 +4026,31 @@ def _delay_for_event_type(self, event_type, delay_window): self._event_schedule_times[event_type] = this_time return delay + def _refresh_nodes_if_not_up(self, host): + """ + Used to mitigate refreshes for nodes that are already known. + Some versions of the server send superfluous NEW_NODE messages in addition to UP events. + """ + if not host or not host.is_up: + self.refresh_node_list_and_token_map() + def _handle_topology_change(self, event): change_type = event["change_type"] addr, port = event["address"] + host = self._cluster.metadata.get_host(addr, port) if change_type == "NEW_NODE" or change_type == "MOVED_NODE": if self._topology_event_refresh_window >= 0: delay = self._delay_for_event_type('topology_change', self._topology_event_refresh_window) - self._cluster.scheduler.schedule_unique(delay, self.refresh_node_list_and_token_map) + self._cluster.scheduler.schedule_unique(delay, self._refresh_nodes_if_not_up, host) elif change_type == "REMOVED_NODE": - host = self._cluster.metadata.get_host(addr) self._cluster.scheduler.schedule_unique(0, self._cluster.remove_host, host) def _handle_status_change(self, event): change_type = event["change_type"] addr, port = event["address"] - host = self._cluster.metadata.get_host(addr) + host = self._cluster.metadata.get_host(addr, port) if change_type == "UP": - delay = 1 + self._delay_for_event_type('status_change', 0.5) # randomness to avoid thundering herd problem on events + delay = self._delay_for_event_type('status_change', self._status_event_refresh_window) if host is None: # this is the first time we've seen the node self._cluster.scheduler.schedule_unique(delay, self.refresh_node_list_and_token_map) @@ -2379,7 +4059,7 @@ def _handle_status_change(self, event): elif change_type == "DOWN": # Note that there is a slight risk we can receive the event late and thus # mark the host down even though we already had reconnected successfully. - # But it is unlikely, and don't have too much consequence since we'll try reconnecting + # This is unlikely, and will not have much consequence because we'll try reconnecting # right away, so we favor the detection to make the Host.is_up more accurate. 
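The refresh windows consulted by these event handlers are configurable when building the Cluster; a hedged sketch (values are in seconds and purely illustrative):

from cassandra.cluster import Cluster

cluster = Cluster(["127.0.0.1"],
                  topology_event_refresh_window=10,  # debounce NEW_NODE / MOVED_NODE refreshes
                  status_event_refresh_window=2,     # debounce UP-event refreshes
                  schema_event_refresh_window=2)     # debounce schema-change refreshes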
if host is not None: # this will be run by the scheduler @@ -2413,7 +4093,7 @@ def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wai peers_result = preloaded_results[0] local_result = preloaded_results[1] - schema_mismatches = self._get_schema_mismatches(peers_result, local_result, connection.host) + schema_mismatches = self._get_schema_mismatches(peers_result, local_result, connection.endpoint) if schema_mismatches is None: return True @@ -2422,8 +4102,10 @@ def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wai elapsed = 0 cl = ConsistencyLevel.ONE schema_mismatches = None + select_peers_query = self._get_peers_query(self.PeersQueryType.PEERS_SCHEMA, connection) + while elapsed < total_timeout: - peers_query = QueryMessage(query=self._SELECT_SCHEMA_PEERS, consistency_level=cl) + peers_query = QueryMessage(query=select_peers_query, consistency_level=cl) local_query = QueryMessage(query=self._SELECT_SCHEMA_LOCAL, consistency_level=cl) try: timeout = min(self._timeout, total_timeout - elapsed) @@ -2441,7 +4123,7 @@ def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wai else: raise - schema_mismatches = self._get_schema_mismatches(peers_result, local_result, connection.host) + schema_mismatches = self._get_schema_mismatches(peers_result, local_result, connection.endpoint) if schema_mismatches is None: return True @@ -2450,15 +4132,15 @@ def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wai elapsed = self._time.time() - start log.warning("Node %s is reporting a schema disagreement: %s", - connection.host, schema_mismatches) + connection.endpoint, schema_mismatches) return False def _get_schema_mismatches(self, peers_result, local_result, local_address): - peers_result = dict_factory(*peers_result.results) + peers_result = dict_factory(peers_result.column_names, peers_result.parsed_rows) versions = defaultdict(set) - if local_result.results: - local_row = dict_factory(*local_result.results)[0] + if local_result.parsed_rows: + local_row = dict_factory(local_result.column_names, local_result.parsed_rows)[0] if local_row.get("schema_version"): versions[local_row.get("schema_version")].add(local_address) @@ -2466,20 +4148,61 @@ def _get_schema_mismatches(self, peers_result, local_result, local_address): schema_ver = row.get('schema_version') if not schema_ver: continue - - addr = row.get("rpc_address") - if not addr or addr in ["0.0.0.0", "::"]: - addr = row.get("peer") - - peer = self._cluster.metadata.get_host(addr) - if peer and peer.is_up: - versions[schema_ver].add(addr) + endpoint = self._cluster.endpoint_factory.create(row) + peer = self._cluster.metadata.get_host(endpoint) + if peer and peer.is_up is not False: + versions[schema_ver].add(endpoint) if len(versions) == 1: log.debug("[control connection] Schemas match") return None - return dict((version, list(nodes)) for version, nodes in six.iteritems(versions)) + return dict((version, list(nodes)) for version, nodes in versions.items()) + + def _get_peers_query(self, peers_query_type, connection=None): + """ + Determine the peers query to use. + + :param peers_query_type: Should be one of PeersQueryType enum. + + If _uses_peers_v2 is True, return the proper peers_v2 query (no templating). 
+ Else, apply the logic below to choose the peers v1 address column name: + + Given a connection: + + - find the server product version running on the connection's host, + - use that to choose the column name for the transport address (see APOLLO-1130), and + - use that column name in the provided peers query template. + """ + if peers_query_type not in (self.PeersQueryType.PEERS, self.PeersQueryType.PEERS_SCHEMA): + raise ValueError("Invalid peers query type: %s" % peers_query_type) + + if self._uses_peers_v2: + if peers_query_type == self.PeersQueryType.PEERS: + query = self._SELECT_PEERS_V2 if self._token_meta_enabled else self._SELECT_PEERS_NO_TOKENS_V2 + else: + query = self._SELECT_SCHEMA_PEERS_V2 + else: + if peers_query_type == self.PeersQueryType.PEERS and self._token_meta_enabled: + query = self._SELECT_PEERS + else: + query_template = (self._SELECT_SCHEMA_PEERS_TEMPLATE + if peers_query_type == self.PeersQueryType.PEERS_SCHEMA + else self._SELECT_PEERS_NO_TOKENS_TEMPLATE) + + host_release_version = self._cluster.metadata.get_host(connection.endpoint).release_version + host_dse_version = self._cluster.metadata.get_host(connection.endpoint).dse_version + uses_native_address_query = ( + host_dse_version and Version(host_dse_version) >= self._MINIMUM_NATIVE_ADDRESS_DSE_VERSION) + + if uses_native_address_query: + query = query_template.format(nt_col_name="native_transport_address") + elif host_release_version: + query = query_template.format(nt_col_name="rpc_address") + else: + query = self._SELECT_PEERS + + return query def _signal_error(self): with self._lock: @@ -2489,7 +4212,7 @@ def _signal_error(self): # try just signaling the cluster, as this will trigger a reconnect # as part of marking the host down if self._connection and self._connection.is_defunct: - host = self._cluster.metadata.get_host(self._connection.host) + host = self._cluster.metadata.get_host(self._connection.endpoint) # host may be None if it's already been removed, but that indicates # that errors have already been reported, so we're fine if host: @@ -2507,7 +4230,7 @@ def on_up(self, host): def on_down(self, host): conn = self._connection - if conn and conn.host == host.address and \ + if conn and conn.endpoint == host.endpoint and \ self._reconnection_handler is None: log.debug("[control connection] Control connection host (%s) is " "considered down, starting reconnection", host) @@ -2519,7 +4242,13 @@ def on_add(self, host, refresh_nodes=True): self.refresh_node_list_and_token_map(force_token_rebuild=True) def on_remove(self, host): - self.refresh_node_list_and_token_map(force_token_rebuild=True) + c = self._connection + if c and c.endpoint == host.endpoint: + log.debug("[control connection] Control connection host (%s) is being removed. 
Reconnecting", host) + # refresh will be done on reconnect + self.reconnect() + else: + self.refresh_node_list_and_token_map(force_token_rebuild=True) def get_connections(self): c = getattr(self, '_connection', None) @@ -2529,9 +4258,6 @@ def return_connection(self, connection): if connection is self._connection and (connection.is_defunct or connection.is_closed): self.reconnect() - def set_meta_refresh_enabled(self, enabled): - self._meta_refresh_enabled = enabled - def _stop_scheduler(scheduler, thread): try: @@ -2543,7 +4269,7 @@ def _stop_scheduler(scheduler, thread): thread.join() -class _Scheduler(object): +class _Scheduler(Thread): _queue = None _scheduled_tasks = None @@ -2551,18 +4277,14 @@ class _Scheduler(object): is_shutdown = False def __init__(self, executor): - self._queue = Queue.PriorityQueue() + self._queue = queue.PriorityQueue() self._scheduled_tasks = set() self._count = count() self._executor = executor - t = Thread(target=self.run, name="Task Scheduler") - t.daemon = True - t.start() - - # although this runs on a daemonized thread, we prefer to stop - # it gracefully to avoid random errors during interpreter shutdown - atexit.register(partial(_stop_scheduler, weakref.proxy(self), t)) + Thread.__init__(self, name="Task Scheduler") + self.daemon = True + self.start() def shutdown(self): try: @@ -2572,6 +4294,7 @@ def shutdown(self): pass self.is_shutdown = True self._queue.put_nowait((0, 0, None)) + self.join() def schedule(self, delay, fn, *args, **kwargs): self._insert_task(delay, (fn, args, tuple(kwargs.items()))) @@ -2600,7 +4323,8 @@ def run(self): while True: run_at, i, task = self._queue.get(block=True, timeout=None) if self.is_shutdown: - log.debug("Not executing scheduled task due to Scheduler shutdown") + if task: + log.debug("Not executing scheduled task due to Scheduler shutdown") return if run_at <= time.time(): self._scheduled_tasks.discard(task) @@ -2611,7 +4335,7 @@ def run(self): else: self._queue.put_nowait((run_at, i, task)) break - except Queue.Empty: + except queue.Empty: pass time.sleep(0.1) @@ -2624,15 +4348,11 @@ def _log_if_failed(self, future): exc_info=exc) -def refresh_schema_and_set_result(control_conn, response_future, **kwargs): +def refresh_schema_and_set_result(control_conn, response_future, connection, **kwargs): try: - if control_conn._meta_refresh_enabled: - log.debug("Refreshing schema in response to schema change. " - "%s", kwargs) - response_future.is_schema_agreed = control_conn._refresh_schema(response_future._connection, **kwargs) - else: - log.debug("Skipping schema refresh in response to schema change because meta refresh is disabled; " - "%s", kwargs) + log.debug("Refreshing schema in response to schema change. " + "%s", kwargs) + response_future.is_schema_agreed = control_conn._refresh_schema(connection, **kwargs) except Exception: log.exception("Exception refreshing schema in response to schema change:") response_future.session.submit(control_conn.refresh_schema, **kwargs) @@ -2665,20 +4385,38 @@ class ResponseFuture(object): Always ``True`` for non-DDL requests. 
""" + request_encoded_size = None + """ + Size of the request message sent + """ + + coordinator_host = None + """ + The host from which we received a response + """ + + attempted_hosts = None + """ + A list of hosts tried, including all speculative executions, retries, and pages + """ + session = None row_factory = None message = None default_timeout = None + _retry_policy = None + _profile_manager = None + _req_id = None _final_result = _NOT_SET _col_names = None + _col_types = None _final_exception = None _query_traces = None _callbacks = None _errbacks = None _current_host = None - _current_pool = None _connection = None _query_retries = 0 _start_time = None @@ -2688,45 +4426,149 @@ class ResponseFuture(object): _warnings = None _timer = None _protocol_handler = ProtocolHandler + _spec_execution_plan = NoSpeculativeExecutionPlan() + _continuous_paging_options = None + _continuous_paging_session = None + _host = None _warned_timeout = False - def __init__(self, session, message, query, timeout, metrics=None, prepared_statement=None): + def __init__(self, session, message, query, timeout, metrics=None, prepared_statement=None, + retry_policy=RetryPolicy(), row_factory=None, load_balancer=None, start_time=None, + speculative_execution_plan=None, continuous_paging_state=None, host=None): self.session = session - self.row_factory = session.row_factory + # TODO: normalize handling of retry policy and row factory + self.row_factory = row_factory or session.row_factory + self._load_balancer = load_balancer or session.cluster._default_load_balancing_policy self.message = message self.query = query self.timeout = timeout + self._retry_policy = retry_policy self._metrics = metrics self.prepared_statement = prepared_statement self._callback_lock = Lock() - if metrics is not None: - self._start_time = time.time() + self._start_time = start_time or time.time() + self._host = host + self._spec_execution_plan = speculative_execution_plan or self._spec_execution_plan self._make_query_plan() self._event = Event() self._errors = {} self._callbacks = [] self._errbacks = [] + self.attempted_hosts = [] + self._start_timer() + self._continuous_paging_state = continuous_paging_state + + @property + def _time_remaining(self): + if self.timeout is None: + return None + return (self._start_time + self.timeout) - time.time() def _start_timer(self): - if self.timeout is not None: - self._timer = self.session.cluster.connection_class.create_timer(self.timeout, self._on_timeout) + if self._timer is None: + spec_delay = self._spec_execution_plan.next_execution(self._current_host) + if spec_delay >= 0: + if self._time_remaining is None or self._time_remaining > spec_delay: + self._timer = self.session.cluster.connection_class.create_timer(spec_delay, self._on_speculative_execute) + return + if self._time_remaining is not None: + self._timer = self.session.cluster.connection_class.create_timer(self._time_remaining, self._on_timeout) def _cancel_timer(self): if self._timer: self._timer.cancel() - def _on_timeout(self): - self._set_final_exception(OperationTimedOut(self._errors, self._current_host)) + def _on_timeout(self, _attempts=0): + """ + Called when the request associated with this ResponseFuture times out. + + This function may reschedule itself. The ``_attempts`` parameter tracks + the number of times this has happened. This parameter should only be + set in those cases, where ``_on_timeout`` reschedules itself. 
+ """ + # PYTHON-853: for short timeouts, we sometimes race with our __init__ + if self._connection is None and _attempts < 3: + self._timer = self.session.cluster.connection_class.create_timer( + 0.01, + partial(self._on_timeout, _attempts=_attempts + 1) + ) + return + + if self._connection is not None: + try: + self._connection._requests.pop(self._req_id) + # PYTHON-1044 + # This request might have been removed from the connection after the latter was defunct by heartbeat. + # We should still raise OperationTimedOut to reject the future so that the main event thread will not + # wait for it endlessly + except KeyError: + key = "Connection defunct by heartbeat" + errors = {key: "Client request timeout. See Session.execute[_async](timeout)"} + self._set_final_exception(OperationTimedOut(errors, self._current_host)) + return + + pool = self.session._pools.get(self._current_host) + if pool and not pool.is_shutdown: + # Do not return the stream ID to the pool yet. We cannot reuse it + # because the node might still be processing the query and will + # return a late response to that query - if we used such stream + # before the response to the previous query has arrived, the new + # query could get a response from the old query + with self._connection.lock: + self._connection.orphaned_request_ids.add(self._req_id) + if len(self._connection.orphaned_request_ids) >= self._connection.orphaned_threshold: + self._connection.orphaned_threshold_reached = True + + pool.return_connection(self._connection, stream_was_orphaned=True) + + errors = self._errors + if not errors: + if self.is_schema_agreed: + key = str(self._current_host.endpoint) if self._current_host else 'no host queried before timeout' + errors = {key: "Client request timeout. See Session.execute[_async](timeout)"} + else: + connection = self.session.cluster.control_connection._connection + host = str(connection.endpoint) if connection else 'unknown' + errors = {host: "Request timed out while waiting for schema agreement. See Session.execute[_async](timeout) and Cluster.max_schema_agreement_wait."} + + self._set_final_exception(OperationTimedOut(errors, self._current_host)) + + def _on_speculative_execute(self): + self._timer = None + if not self._event.is_set(): + + # PYTHON-836, the speculative queries must be after + # the query is sent from the main thread, otherwise the + # query from the main thread may raise NoHostAvailable + # if the _query_plan has been exhausted by the speculative queries. + # This also prevents a race condition accessing the iterator. 
+ # We reschedule this call until the main thread has succeeded + # making a query + if not self.attempted_hosts: + self._timer = self.session.cluster.connection_class.create_timer(0.01, self._on_speculative_execute) + return + + if self._time_remaining is not None: + if self._time_remaining <= 0: + self._on_timeout() + return + self.send_request(error_no_hosts=False) + self._start_timer() def _make_query_plan(self): - # convert the list/generator/etc to an iterator so that subsequent - # calls to send_request (which retries may do) will resume where - # they last left off - self.query_plan = iter(self.session._load_balancer.make_query_plan( - self.session.keyspace, self.query)) + # set the query_plan according to the load balancing policy, + # or to the explicit host target if set + if self._host: + # returning a single value effectively disables retries + self.query_plan = [self._host] + else: + # convert the list/generator/etc to an iterator so that subsequent + # calls to send_request (which retries may do) will resume where + # they last left off + self.query_plan = iter(self._load_balancer.make_query_plan(self.session.keyspace, self.query)) - def send_request(self): + def send_request(self, error_no_hosts=True): """ Internal """ # query_plan is an iterator, so this will resume where we last left # off if send_request() is called multiple times @@ -2734,24 +4576,19 @@ def send_request(self): req_id = self._query(host) if req_id is not None: self._req_id = req_id - - # timer is only started here, after we have at least one message queued - # this is done to avoid overrun of timers with unfettered client requests - # in the case of full disconnect, where no hosts will be available - if self._timer is None: - self._start_timer() - return - - self._set_final_exception(NoHostAvailable( - "Unable to complete the operation against any hosts", self._errors)) + return True + if self.timeout is not None and time.time() - self._start_time > self.timeout: + self._on_timeout() + return True + if error_no_hosts: + self._set_final_exception(NoHostAvailable( + "Unable to complete the operation against any hosts", self._errors)) + return False def _query(self, host, message=None, cb=None): if message is None: message = self.message - if cb is None: - cb = self._set_result - pool = self.session._pools.get(host) if not pool: self._errors[host] = ConnectionException("Host has been marked down or removed") @@ -2761,19 +4598,29 @@ def _query(self, host, message=None, cb=None): return None self._current_host = host - self._current_pool = pool connection = None try: # TODO get connectTimeout from cluster settings connection, request_id = pool.borrow_connection(timeout=2.0) self._connection = connection - connection.send_msg(message, request_id, cb=cb, encoder=self._protocol_handler.encode_message, decoder=self._protocol_handler.decode_message) + result_meta = self.prepared_statement.result_metadata if self.prepared_statement else [] + + if cb is None: + cb = partial(self._set_result, host, connection, pool) + + self.request_encoded_size = connection.send_msg(message, request_id, cb=cb, + encoder=self._protocol_handler.encode_message, + decoder=self._protocol_handler.decode_message, + result_metadata=result_meta) + self.attempted_hosts.append(host) return request_id except NoConnectionsAvailable as exc: log.debug("All connections for host %s are at capacity, moving to the next host", host) self._errors[host] = exc - return None + except ConnectionBusy as exc: + log.debug("Connection for host %s is busy, 
moving to the next host", host) + self._errors[host] = exc except Exception as exc: log.debug("Error querying host %s", host, exc_info=True) self._errors[host] = exc @@ -2781,7 +4628,8 @@ def _query(self, host, message=None, cb=None): self._metrics.on_connection_error() if connection: pool.return_connection(connection) - return None + + return None @property def has_more_pages(self): @@ -2805,11 +4653,11 @@ def warnings(self): Ensure the future is complete before trying to access this property (call :meth:`.result()`, or after callback is invoked). - Otherwise it may throw if the response has not been received. + Otherwise, it may throw if the response has not been received. """ # TODO: When timers are introduced, just make this wait if not self._event.is_set(): - raise Exception("warnings cannot be retrieved before ResponseFuture is finalized") + raise DriverException("warnings cannot be retrieved before ResponseFuture is finalized") return self._warnings @property @@ -2821,13 +4669,13 @@ def custom_payload(self): Ensure the future is complete before trying to access this property (call :meth:`.result()`, or after callback is invoked). - Otherwise it may throw if the response has not been received. + Otherwise, it may throw if the response has not been received. :return: :ref:`custom_payload`. """ # TODO: When timers are introduced, just make this wait if not self._event.is_set(): - raise Exception("custom_payload cannot be retrieved before ResponseFuture is finalized") + raise DriverException("custom_payload cannot be retrieved before ResponseFuture is finalized") return self._custom_payload def start_fetching_next_page(self): @@ -2848,20 +4696,21 @@ def start_fetching_next_page(self): self._event.clear() self._final_result = _NOT_SET self._final_exception = None - self._timer = None # clear cancelled timer; new one will be set when request is queued + self._start_timer() self.send_request() - def _reprepare(self, prepare_message): - cb = partial(self.session.submit, self._execute_after_prepare) - request_id = self._query(self._current_host, prepare_message, cb=cb) + def _reprepare(self, prepare_message, host, connection, pool): + cb = partial(self.session.submit, self._execute_after_prepare, host, connection, pool) + request_id = self._query(host, prepare_message, cb=cb) if request_id is None: # try to submit the original prepared statement on some other host self.send_request() - def _set_result(self, response): + def _set_result(self, host, connection, pool, response): try: - if self._current_pool and self._connection: - self._current_pool.return_connection(self._connection) + self.coordinator_host = host + if pool: + pool.return_connection(connection) trace_id = getattr(response, 'trace_id', None) if trace_id: @@ -2884,7 +4733,7 @@ def _set_result(self, response): # event loop thread. 
if session: session._set_keyspace_for_all_pools( - response.results, self._set_keyspace_completed) + response.new_keyspace, self._set_keyspace_completed) elif response.kind == RESULT_KIND_SCHEMA_CHANGE: # refresh the schema before responding, but do it in another # thread instead of the event loop thread @@ -2892,20 +4741,21 @@ def _set_result(self, response): self.session.submit( refresh_schema_and_set_result, self.session.cluster.control_connection, - self, **response.results) + self, connection, **response.schema_change_event) + elif response.kind == RESULT_KIND_ROWS: + self._paging_state = response.paging_state + self._col_names = response.column_names + self._col_types = response.column_types + if getattr(self.message, 'continuous_paging_options', None): + self._handle_continuous_paging_first_response(connection, response) + else: + self._set_final_result(self.row_factory(response.column_names, response.parsed_rows)) + elif response.kind == RESULT_KIND_VOID: + self._set_final_result(None) else: - results = getattr(response, 'results', None) - if results is not None and response.kind == RESULT_KIND_ROWS: - self._paging_state = response.paging_state - self._col_names = results[0] - results = self.row_factory(*results) - self._set_final_result(results) + self._set_final_result(response) elif isinstance(response, ErrorMessage): - retry_policy = None - if self.query: - retry_policy = self.query.retry_policy - if not retry_policy: - retry_policy = self.session.cluster.default_retry_policy + retry_policy = self._retry_policy if isinstance(response, ReadTimeoutErrorMessage): if self._metrics is not None: @@ -2922,20 +4772,16 @@ def _set_result(self, response): self._metrics.on_unavailable() retry = retry_policy.on_unavailable( self.query, retry_num=self._query_retries, **response.info) - elif isinstance(response, OverloadedErrorMessage): + elif isinstance(response, (OverloadedErrorMessage, + IsBootstrappingErrorMessage, + TruncateError, ServerError)): + log.warning("Host %s error: %s.", host, response.summary) if self._metrics is not None: self._metrics.on_other_error() - # need to retry against a different host here - log.warning("Host %s is overloaded, retrying against a different " - "host", self._current_host) - self._retry(reuse_connection=False, consistency_level=None) - return - elif isinstance(response, IsBootstrappingErrorMessage): - if self._metrics is not None: - self._metrics.on_other_error() - # need to retry against a different host here - self._retry(reuse_connection=False, consistency_level=None) - return + cl = getattr(self.message, 'consistency_level', None) + retry = retry_policy.on_request_error( + self.query, cl, error=response, + retry_num=self._query_retries) elif isinstance(response, PreparedQueryNotFound): if self.prepared_statement: query_id = self.prepared_statement.query_id @@ -2959,7 +4805,8 @@ def _set_result(self, response): current_keyspace = self._connection.keyspace prepared_keyspace = prepared_statement.keyspace - if prepared_keyspace and current_keyspace != prepared_keyspace: + if not ProtocolVersion.uses_keyspace_flag(self.session.cluster.protocol_version) \ + and prepared_keyspace and current_keyspace != prepared_keyspace: self._set_final_exception( ValueError("The Session's current keyspace (%s) does " "not match the keyspace the statement was " @@ -2968,11 +4815,14 @@ def _set_result(self, response): return log.debug("Re-preparing unrecognized prepared statement against host %s: %s", - self._current_host, prepared_statement.query_string) - 
prepare_message = PrepareMessage(query=prepared_statement.query_string) + host, prepared_statement.query_string) + prepared_keyspace = prepared_statement.keyspace \ + if ProtocolVersion.uses_keyspace_flag(self.session.cluster.protocol_version) else None + prepare_message = PrepareMessage(query=prepared_statement.query_string, + keyspace=prepared_keyspace) # since this might block, run on the executor to avoid hanging # the event loop thread - self.session.submit(self._reprepare, prepare_message) + self.session.submit(self._reprepare, prepare_message, host, connection, pool) return else: if hasattr(response, 'to_exception'): @@ -2981,22 +4831,16 @@ def _set_result(self, response): self._set_final_exception(response) return - retry_type, consistency = retry - if retry_type is RetryPolicy.RETRY: - self._query_retries += 1 - self._retry(reuse_connection=True, consistency_level=consistency) - elif retry_type is RetryPolicy.RETHROW: - self._set_final_exception(response.to_exception()) - else: # IGNORE - if self._metrics is not None: - self._metrics.on_ignore() - self._set_final_result(None) + self._handle_retry_decision(retry, response, host) elif isinstance(response, ConnectionException): if self._metrics is not None: self._metrics.on_connection_error() if not isinstance(response, ConnectionShutdown): self._connection.defunct(response) - self._retry(reuse_connection=False, consistency_level=None) + cl = getattr(self.message, 'consistency_level', None) + retry = self._retry_policy.on_request_error( + self.query, cl, error=response, retry_num=self._query_retries) + self._handle_retry_decision(retry, response, host) elif isinstance(response, Exception): if hasattr(response, 'to_exception'): self._set_final_exception(response.to_exception()) @@ -3005,7 +4849,8 @@ def _set_result(self, response): else: # we got some other kind of response message msg = "Got unexpected message: %r" % (response,) - exc = ConnectionException(msg, self._current_host) + exc = ConnectionException(msg, host) + self._cancel_timer() self._connection.defunct(exc) self._set_final_exception(exc) except Exception as exc: @@ -3013,6 +4858,14 @@ def _set_result(self, response): log.exception("Unexpected exception while handling result in ResponseFuture:") self._set_final_exception(exc) + def _handle_continuous_paging_first_response(self, connection, response): + self._continuous_paging_session = connection.new_continuous_paging_session(response.stream_id, + self._protocol_handler.decode_message, + self.row_factory, + self._continuous_paging_state) + self._continuous_paging_session.on_message(response) + self._set_final_result(self._continuous_paging_session.results()) + def _set_keyspace_completed(self, errors): if not errors: self._set_final_result(None) @@ -3020,29 +4873,44 @@ def _set_keyspace_completed(self, errors): self._set_final_exception(ConnectionException( "Failed to set keyspace on all hosts: %s" % (errors,))) - def _execute_after_prepare(self, response): + def _execute_after_prepare(self, host, connection, pool, response): """ Handle the response to our attempt to prepare a statement. If it succeeded, run the original query again against the same host. 
""" - if self._current_pool and self._connection: - self._current_pool.return_connection(self._connection) + if pool: + pool.return_connection(connection) if self._final_exception: return if isinstance(response, ResultMessage): if response.kind == RESULT_KIND_PREPARED: + if self.prepared_statement: + if self.prepared_statement.query_id != response.query_id: + self._set_final_exception(DriverException( + "ID mismatch while trying to reprepare (expected {expected}, got {got}). " + "This prepared statement won't work anymore. " + "This usually happens when you run a 'USE...' " + "query after the statement was prepared.".format( + expected=hexlify(self.prepared_statement.query_id), got=hexlify(response.query_id) + ) + )) + self.prepared_statement.result_metadata = response.column_metadata + new_metadata_id = response.result_metadata_id + if new_metadata_id is not None: + self.prepared_statement.result_metadata_id = new_metadata_id + # use self._query to re-use the same host and # at the same time properly borrow the connection - request_id = self._query(self._current_host) + request_id = self._query(host) if request_id is None: # this host errored out, move on to the next self.send_request() else: self._set_final_exception(ConnectionException( "Got unexpected response when preparing statement " - "on host %s: %s" % (self._current_host, response))) + "on host %s: %s" % (host, response))) elif isinstance(response, ErrorMessage): if hasattr(response, 'to_exception'): self._set_final_exception(response.to_exception()) @@ -3050,14 +4918,14 @@ def _execute_after_prepare(self, response): self._set_final_exception(response) elif isinstance(response, ConnectionException): log.debug("Connection error when preparing statement on host %s: %s", - self._current_host, response) + host, response) # try again on a different host, preparing again if necessary - self._errors[self._current_host] = response + self._errors[host] = response self.send_request() else: self._set_final_exception(ConnectionException( "Got unexpected response type when preparing " - "statement on host %s: %s" % (self._current_host, response))) + "statement on host %s: %s" % (host, response))) def _set_final_result(self, response): self._cancel_timer() @@ -3066,13 +4934,20 @@ def _set_final_result(self, response): with self._callback_lock: self._final_result = response + # save off current callbacks inside lock for execution outside it + # -- prevents case where _final_result is set, then a callback is + # added and executed on the spot, then executed again as a + # registered callback + to_call = tuple( + partial(fn, response, *args, **kwargs) + for (fn, args, kwargs) in self._callbacks + ) self._event.set() # apply each callback - for callback in self._callbacks: - fn, args, kwargs = callback - fn(response, *args, **kwargs) + for callback_partial in to_call: + callback_partial() def _set_final_exception(self, response): self._cancel_timer() @@ -3081,13 +4956,43 @@ def _set_final_exception(self, response): with self._callback_lock: self._final_exception = response + # save off current errbacks inside lock for execution outside it -- + # prevents case where _final_exception is set, then an errback is + # added and executed on the spot, then executed again as a + # registered errback + to_call = tuple( + partial(fn, response, *args, **kwargs) + for (fn, args, kwargs) in self._errbacks + ) self._event.set() - for errback in self._errbacks: - fn, args, kwargs = errback - fn(response, *args, **kwargs) + # apply each callback + for 
callback_partial in to_call: + callback_partial() + + def _handle_retry_decision(self, retry_decision, response, host): - def _retry(self, reuse_connection, consistency_level): + def exception_from_response(response): + if hasattr(response, 'to_exception'): + return response.to_exception() + else: + return response + + retry_type, consistency = retry_decision + if retry_type in (RetryPolicy.RETRY, RetryPolicy.RETRY_NEXT_HOST): + self._query_retries += 1 + reuse = retry_type == RetryPolicy.RETRY + self._retry(reuse, consistency, host) + elif retry_type is RetryPolicy.RETHROW: + self._set_final_exception(exception_from_response(response)) + else: # IGNORE + if self._metrics is not None: + self._metrics.on_ignore() + self._set_final_result(None) + + self._errors[host] = exception_from_response(response) + + def _retry(self, reuse_connection, consistency_level, host): if self._final_exception: # the connection probably broke while we were waiting # to retry the operation @@ -3099,15 +5004,15 @@ def _retry(self, reuse_connection, consistency_level): self.message.consistency_level = consistency_level # don't retry on the event loop thread - self.session.submit(self._retry_task, reuse_connection) + self.session.submit(self._retry_task, reuse_connection, host) - def _retry_task(self, reuse_connection): + def _retry_task(self, reuse_connection, host): if self._final_exception: # the connection probably broke while we were waiting # to retry the operation return - if reuse_connection and self._query(self._current_host) is not None: + if reuse_connection and self._query(host) is not None: return # otherwise, move onto another host @@ -3150,32 +5055,41 @@ def get_query_trace_ids(self): """ return [trace.trace_id for trace in self._query_traces] - def get_query_trace(self, max_wait=None): + def get_query_trace(self, max_wait=None, query_cl=ConsistencyLevel.LOCAL_ONE): """ Fetches and returns the query trace of the last response, or `None` if tracing was not enabled. Note that this may raise an exception if there are problems retrieving the trace - details from Cassandra. If the trace is not available after `max_wait_sec`, + details from Cassandra. If the trace is not available after `max_wait`, + :exc:`cassandra.query.TraceUnavailable` will be raised. + + If the ResponseFuture is not done (async execution) and you try to retrieve the trace, :exc:`cassandra.query.TraceUnavailable` will be raised. + + `query_cl` is the consistency level used to poll the trace tables. """ + if self._final_result is _NOT_SET and self._final_exception is None: + raise TraceUnavailable( + "Trace information was not available. The ResponseFuture is not done.") + if self._query_traces: - return self._get_query_trace(len(self._query_traces) - 1, max_wait) + return self._get_query_trace(len(self._query_traces) - 1, max_wait, query_cl) - def get_all_query_traces(self, max_wait_per=None): + def get_all_query_traces(self, max_wait_per=None, query_cl=ConsistencyLevel.LOCAL_ONE): """ Fetches and returns the query traces for all query pages, if tracing was enabled. See note in :meth:`~.get_query_trace` regarding possible exceptions. 
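For context, the trace machinery above is typically consumed from the application side like this; a minimal sketch, assuming an already connected ``session`` and a placeholder table name::

    # Ask the server to trace this request; the trace rows are fetched lazily.
    future = session.execute_async("SELECT * FROM users", trace=True)
    rows = future.result()                      # the future must be done before tracing

    # Poll the trace tables (LOCAL_ONE by default, per the new query_cl argument).
    trace = future.get_query_trace(max_wait=2.0)
    if trace:
        print(trace.coordinator, trace.duration)
        for event in trace.events:
            print(event.source_elapsed, event.description)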
""" if self._query_traces: - return [self._get_query_trace(i, max_wait_per) for i in range(len(self._query_traces))] + return [self._get_query_trace(i, max_wait_per, query_cl) for i in range(len(self._query_traces))] return [] - def _get_query_trace(self, i, max_wait): + def _get_query_trace(self, i, max_wait, query_cl): trace = self._query_traces[i] if not trace.events: - trace.populate(max_wait=max_wait) + trace.populate(max_wait=max_wait, query_cl=query_cl) return trace def add_callback(self, fn, *args, **kwargs): @@ -3217,10 +5131,12 @@ def add_callback(self, fn, *args, **kwargs): """ run_now = False with self._callback_lock: + # Always add fn to self._callbacks, even when we're about to + # execute it, to prevent races with functions like + # start_fetching_next_page that reset _final_result + self._callbacks.append((fn, args, kwargs)) if self._final_result is not _NOT_SET: run_now = True - else: - self._callbacks.append((fn, args, kwargs)) if run_now: fn(self._final_result, *args, **kwargs) return self @@ -3233,10 +5149,12 @@ def add_errback(self, fn, *args, **kwargs): """ run_now = False with self._callback_lock: + # Always add fn to self._errbacks, even when we're about to execute + # it, to prevent races with functions like start_fetching_next_page + # that reset _final_exception + self._errbacks.append((fn, args, kwargs)) if self._final_exception: run_now = True - else: - self._errbacks.append((fn, args, kwargs)) if run_now: fn(self._final_exception, *args, **kwargs) return self @@ -3271,13 +5189,13 @@ def add_callbacks(self, callback, errback, def clear_callbacks(self): with self._callback_lock: - self._callback = [] - self._errback = [] + self._callbacks = [] + self._errbacks = [] def __str__(self): result = "(no result yet)" if self._final_result is _NOT_SET else self._final_result - return "" \ - % (self.query, self._req_id, result, self._final_exception, self._current_host) + return "" \ + % (self.query, self._req_id, result, self._final_exception, self.coordinator_host) __repr__ = __str__ @@ -3316,6 +5234,7 @@ class ResultSet(object): def __init__(self, response_future, initial_response): self.response_future = response_future self.column_names = response_future._col_names + self.column_types = response_future._col_types self._set_current_rows(initial_response) self._page_iter = None self._list_mode = False @@ -3335,6 +5254,31 @@ def current_rows(self): """ return self._current_rows or [] + def all(self): + """ + Returns all the remaining rows as a list. This is basically + a convenient shortcut to `list(result_set)`. + + This function is not recommended for queries that return a large number of elements. + """ + return list(self) + + def one(self): + """ + Return a single row of the results or None if empty. This is basically + a shortcut to `result_set.current_rows[0]` and should only be used when + you know a query returns a single row. Consider using an iterator if the + ResultSet contains more than one row. 
+ """ + row = None + if self._current_rows: + try: + row = self._current_rows[0] + except TypeError: # generator object is not subscriptable, PYTHON-1026 + row = next(iter(self._current_rows)) + + return row + def __iter__(self): if self._list_mode: return iter(self._current_rows) @@ -3350,8 +5294,15 @@ def next(self): self._current_rows = [] raise - self.fetch_next_page() - self._page_iter = iter(self._current_rows) + if not self.response_future._continuous_paging_session: + self.fetch_next_page() + self._page_iter = iter(self._current_rows) + + # Some servers can return empty pages in this case; Scylla is known to do + # so in some circumstances. Guard against this by recursing to handle + # the next(iter) call. If we have an empty page in that case it will + # get handled by the StopIteration handler when we recurse. + return self.next() return next(self._page_iter) @@ -3399,6 +5350,9 @@ def __eq__(self, other): return self._current_rows == other def __getitem__(self, i): + if i == 0: + warn("ResultSet indexing support will be removed in 4.0. Consider using " + "ResultSet.one() to get a single row.", DeprecationWarning) self._enter_list_mode("index operator") return self._current_rows[i] @@ -3421,18 +5375,31 @@ def get_all_query_traces(self, max_wait_sec_per=None): """ return self.response_future.get_all_query_traces(max_wait_sec_per) + def cancel_continuous_paging(self): + try: + self.response_future._continuous_paging_session.cancel() + except AttributeError: + raise DriverException("Attempted to cancel paging with no active session. This is only for requests with ContinuousPagingOptions.") + @property def was_applied(self): """ For LWT results, returns whether the transaction was applied. - Result is indeterminate if called on a result that was not an LWT request. + Result is indeterminate if called on a result that was not an LWT request or on + a :class:`.query.BatchStatement` containing LWT. In the latter case either all the batch + succeeds or fails. - Only valid when one of tne of the internal row factories is in use. + Only valid when one of the internal row factories is in use. """ if self.response_future.row_factory not in (named_tuple_factory, dict_factory, tuple_factory): - raise RuntimeError("Cannot determine LWT result with row factory %s" % (self.response_future.row_factsory,)) - if len(self.current_rows) != 1: + raise RuntimeError("Cannot determine LWT result with row factory %s" % (self.response_future.row_factory,)) + + is_batch_statement = isinstance(self.response_future.query, BatchStatement) + if is_batch_statement and (not self.column_names or self.column_names[0] != "[applied]"): + raise RuntimeError("No LWT were present in the BatchStatement") + + if not is_batch_statement and len(self.current_rows) != 1: raise RuntimeError("LWT result should have exactly one row. This has %d." % (len(self.current_rows))) row = self.current_rows[0] @@ -3441,3 +5408,12 @@ def was_applied(self): else: return row['[applied]'] + @property + def paging_state(self): + """ + Server paging state of the query. Can be `None` if the query was not paged. + + The driver treats paging state as opaque, but it may contain primary key data, so applications may want to + avoid sending this to untrusted parties. 
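A usage sketch tying the new ResultSet helpers to the exposed paging state; the statement, fetch size and handling of the saved token are placeholders, and the default named-tuple row factory is assumed::

    from cassandra.query import SimpleStatement

    statement = SimpleStatement("SELECT * FROM events", fetch_size=100)
    results = session.execute(statement)

    first = results.one()                   # single row or None, no indexing needed
    saved_state = results.paging_state      # opaque token; may embed primary key data

    # Later, resume fetching from where the first page left off.
    next_page = session.execute(statement, paging_state=saved_state).current_rows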
+ """ + return self.response_future._paging_state diff --git a/cassandra/cmurmur3.c b/cassandra/cmurmur3.c index 2f4cfa0fd6..4affdad46c 100644 --- a/cassandra/cmurmur3.c +++ b/cassandra/cmurmur3.c @@ -6,7 +6,7 @@ * * Copyright (c) 2011 Austin Appleby (Murmur3 routine) * Copyright (c) 2011 Patrick Hensley (Python wrapper, packaging) - * Copyright 2013-2016 DataStax (Minor modifications to match Cassandra's MM3 hashes) + * Copyright DataStax (Minor modifications to match Cassandra's MM3 hashes) * */ @@ -14,12 +14,6 @@ #include #include -#if PY_VERSION_HEX < 0x02050000 -typedef int Py_ssize_t; -#define PY_SSIZE_T_MAX INT_MAX -#define PY_SSIZE_T_MIN INT_MIN -#endif - #ifdef PYPY_VERSION #define COMPILING_IN_PYPY 1 #define COMPILING_IN_CPYTHON 0 @@ -216,8 +210,6 @@ static PyMethodDef cmurmur3_methods[] = { {NULL, NULL, 0, NULL} }; -#if PY_MAJOR_VERSION >= 3 - static int cmurmur3_traverse(PyObject *m, visitproc visit, void *arg) { Py_VISIT(GETSTATE(m)->error); return 0; @@ -245,18 +237,8 @@ static struct PyModuleDef moduledef = { PyObject * PyInit_cmurmur3(void) -#else -#define INITERROR return - -void -initcmurmur3(void) -#endif { -#if PY_MAJOR_VERSION >= 3 PyObject *module = PyModule_Create(&moduledef); -#else - PyObject *module = Py_InitModule("cmurmur3", cmurmur3_methods); -#endif struct module_state *st = NULL; if (module == NULL) @@ -269,7 +251,5 @@ initcmurmur3(void) INITERROR; } -#if PY_MAJOR_VERSION >= 3 return module; -#endif } diff --git a/cassandra/column_encryption/_policies.py b/cassandra/column_encryption/_policies.py new file mode 100644 index 0000000000..e1519f6b79 --- /dev/null +++ b/cassandra/column_encryption/_policies.py @@ -0,0 +1,141 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import namedtuple +from functools import lru_cache + +import logging +import os + +log = logging.getLogger(__name__) + +from cassandra.cqltypes import _cqltypes +from cassandra.policies import ColumnEncryptionPolicy + +from cryptography.hazmat.primitives import padding +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes + +AES256_BLOCK_SIZE = 128 +AES256_BLOCK_SIZE_BYTES = int(AES256_BLOCK_SIZE / 8) +AES256_KEY_SIZE = 256 +AES256_KEY_SIZE_BYTES = int(AES256_KEY_SIZE / 8) + +ColData = namedtuple('ColData', ['key','type']) + +class AES256ColumnEncryptionPolicy(ColumnEncryptionPolicy): + + # Fix block cipher mode for now. IV size is a function of block cipher used + # so fixing this avoids (possibly unnecessary) validation logic here. + mode = modes.CBC + + # "iv" param here expects a bytearray that's the same size as the block + # size for AES-256 (128 bits or 16 bytes). 
If none is provided a new one + # will be randomly generated, but in this case the IV should be recorded and + # preserved or else you will not be able to decrypt any data encrypted by this + # policy. + def __init__(self, iv=None): + + # CBC uses an IV that's the same size as the block size + # + # Avoid defining IV with a default arg in order to stay away from + # any issues around the caching of default args + self.iv = iv + if self.iv: + if not len(self.iv) == AES256_BLOCK_SIZE_BYTES: + raise ValueError("This policy uses AES-256 with CBC mode and therefore expects a 128-bit initialization vector") + else: + self.iv = os.urandom(AES256_BLOCK_SIZE_BYTES) + + # ColData for a given ColDesc is always preserved. We only create a Cipher + # when there's an actual need to for a given ColDesc + self.coldata = {} + self.ciphers = {} + + def encrypt(self, coldesc, obj_bytes): + + # AES256 has a 128-bit block size so if the input bytes don't align perfectly on + # those blocks we have to pad them. There's plenty of room for optimization here: + # + # * Instances of the PKCS7 padder should be managed in a bounded pool + # * It would be nice if we could get a flag from encrypted data to indicate + # whether it was padded or not + # * Might be able to make this happen with a leading block of flags in encrypted data + padder = padding.PKCS7(AES256_BLOCK_SIZE).padder() + padded_bytes = padder.update(obj_bytes) + padder.finalize() + + cipher = self._get_cipher(coldesc) + encryptor = cipher.encryptor() + return self.iv + encryptor.update(padded_bytes) + encryptor.finalize() + + def decrypt(self, coldesc, bytes): + + iv = bytes[:AES256_BLOCK_SIZE_BYTES] + encrypted_bytes = bytes[AES256_BLOCK_SIZE_BYTES:] + cipher = self._get_cipher(coldesc, iv=iv) + decryptor = cipher.decryptor() + padded_bytes = decryptor.update(encrypted_bytes) + decryptor.finalize() + + unpadder = padding.PKCS7(AES256_BLOCK_SIZE).unpadder() + return unpadder.update(padded_bytes) + unpadder.finalize() + + def add_column(self, coldesc, key, type): + + if not coldesc: + raise ValueError("ColDesc supplied to add_column cannot be None") + if not key: + raise ValueError("Key supplied to add_column cannot be None") + if not type: + raise ValueError("Type supplied to add_column cannot be None") + if type not in _cqltypes.keys(): + raise ValueError("Type %s is not a supported type".format(type)) + if not len(key) == AES256_KEY_SIZE_BYTES: + raise ValueError("AES256 column encryption policy expects a 256-bit encryption key") + self.coldata[coldesc] = ColData(key, _cqltypes[type]) + + def contains_column(self, coldesc): + return coldesc in self.coldata + + def encode_and_encrypt(self, coldesc, obj): + if not coldesc: + raise ValueError("ColDesc supplied to encode_and_encrypt cannot be None") + if not obj: + raise ValueError("Object supplied to encode_and_encrypt cannot be None") + coldata = self.coldata.get(coldesc) + if not coldata: + raise ValueError("Could not find ColData for ColDesc %s".format(coldesc)) + return self.encrypt(coldesc, coldata.type.serialize(obj, None)) + + def cache_info(self): + return AES256ColumnEncryptionPolicy._build_cipher.cache_info() + + def column_type(self, coldesc): + return self.coldata[coldesc].type + + def _get_cipher(self, coldesc, iv=None): + """ + Access relevant state from this instance necessary to create a Cipher and then get one, + hopefully returning a cached instance if we've already done so (and it hasn't been evicted) + """ + try: + coldata = self.coldata[coldesc] + return 
AES256ColumnEncryptionPolicy._build_cipher(coldata.key, iv or self.iv) + except KeyError: + raise ValueError("Could not find column {}".format(coldesc)) + + # Explicitly use a class method here to avoid caching self + @lru_cache(maxsize=128) + def _build_cipher(key, iv): + return Cipher(algorithms.AES256(key), AES256ColumnEncryptionPolicy.mode(iv)) diff --git a/cassandra/column_encryption/policies.py b/cassandra/column_encryption/policies.py new file mode 100644 index 0000000000..a1bd25d3e6 --- /dev/null +++ b/cassandra/column_encryption/policies.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import cryptography + from cassandra.column_encryption._policies import * +except ImportError: + # Cryptography is not installed + pass diff --git a/cassandra/concurrent.py b/cassandra/concurrent.py index 75c2604db8..012f52f954 100644 --- a/cassandra/concurrent.py +++ b/cassandra/concurrent.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -13,70 +15,23 @@ # limitations under the License. +import logging +from collections import namedtuple +from concurrent.futures import Future from heapq import heappush, heappop from itertools import cycle -import six -from six.moves import xrange, zip from threading import Condition -import sys -from cassandra.cluster import ResultSet +from cassandra.cluster import ResultSet, EXEC_PROFILE_DEFAULT -import logging log = logging.getLogger(__name__) -def execute_concurrent(session, statements_and_parameters, concurrency=100, raise_on_first_error=True, results_generator=False): - """ - Executes a sequence of (statement, parameters) tuples concurrently. Each - ``parameters`` item must be a sequence or :const:`None`. - - The `concurrency` parameter controls how many statements will be executed - concurrently. 
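Stepping back to the column-encryption policy defined above, the intended wiring looks roughly like this; a hedged sketch in which the keyspace, table, column and key handling are placeholders, and the ``ColDesc``/``column_encryption_policy`` names follow the driver's column-encryption documentation::

    import os
    from cassandra.cluster import Cluster
    from cassandra.policies import ColDesc
    from cassandra.column_encryption.policies import AES256ColumnEncryptionPolicy

    key = os.urandom(32)                            # 256-bit key; persist it securely
    cle_policy = AES256ColumnEncryptionPolicy()
    cle_policy.add_column(ColDesc('ks', 'users', 'ssn'), key, 'text')

    cluster = Cluster(column_encryption_policy=cle_policy)
    session = cluster.connect()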
When :attr:`.Cluster.protocol_version` is set to 1 or 2, - it is recommended that this be kept below 100 times the number of - core connections per host times the number of connected hosts (see - :meth:`.Cluster.set_core_connections_per_host`). If that amount is exceeded, - the event loop thread may attempt to block on new connection creation, - substantially impacting throughput. If :attr:`~.Cluster.protocol_version` - is 3 or higher, you can safely experiment with higher levels of concurrency. - - If `raise_on_first_error` is left as :const:`True`, execution will stop - after the first failed statement and the corresponding exception will be - raised. - - `results_generator` controls how the results are returned. - If :const:`False`, the results are returned only after all requests have completed. - - If :const:`True`, a generator expression is returned. Using a generator results in a constrained - memory footprint when the results set will be large -- results are yielded - as they return instead of materializing the entire list at once. The trade for lower memory - footprint is marginal CPU overhead (more thread coordination and sorting out-of-order results - on-the-fly). - - A sequence of ``(success, result_or_exc)`` tuples is returned in the same - order that the statements were passed in. If ``success`` is :const:`False`, - there was an error executing the statement, and ``result_or_exc`` will be - an :class:`Exception`. If ``success`` is :const:`True`, ``result_or_exc`` - will be the query result. - - Example usage:: - - select_statement = session.prepare("SELECT * FROM users WHERE id=?") - - statements_and_params = [] - for user_id in user_ids: - params = (user_id, ) - statements_and_params.append((select_statement, params)) - - results = execute_concurrent( - session, statements_and_params, raise_on_first_error=False) - - for (success, result) in results: - if not success: - handle_error(result) # result will be an Exception - else: - process_user(result[0]) # result will be a list of rows +ExecutionResult = namedtuple('ExecutionResult', ['success', 'result_or_exc']) +def execute_concurrent(session, statements_and_parameters, concurrency=100, raise_on_first_error=True, results_generator=False, execution_profile=EXEC_PROFILE_DEFAULT): + """ + See :meth:`.Session.execute_concurrent`. 
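Example usage, carried over in spirit from the docstring removed above; ``handle_error`` and ``process_rows`` are placeholder callables::

    from cassandra.concurrent import execute_concurrent

    select = session.prepare("SELECT * FROM users WHERE id=?")
    statements_and_params = [(select, (user_id,)) for user_id in user_ids]

    results = execute_concurrent(
        session, statements_and_params, concurrency=50, raise_on_first_error=False)

    for success, result_or_exc in results:
        if not success:
            handle_error(result_or_exc)     # result_or_exc is an Exception
        else:
            process_rows(result_or_exc)     # result_or_exc is the query result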
""" if concurrency <= 0: raise ValueError("concurrency must be greater than 0") @@ -84,20 +39,25 @@ def execute_concurrent(session, statements_and_parameters, concurrency=100, rais if not statements_and_parameters: return [] - executor = ConcurrentExecutorGenResults(session, statements_and_parameters) if results_generator else ConcurrentExecutorListResults(session, statements_and_parameters) + executor = ConcurrentExecutorGenResults(session, statements_and_parameters, execution_profile) \ + if results_generator else ConcurrentExecutorListResults(session, statements_and_parameters, execution_profile) return executor.execute(concurrency, raise_on_first_error) class _ConcurrentExecutor(object): - def __init__(self, session, statements_and_params): + max_error_recursion = 100 + + def __init__(self, session, statements_and_params, execution_profile): self.session = session self._enum_statements = enumerate(iter(statements_and_params)) + self._execution_profile = execution_profile self._condition = Condition() self._fail_fast = False self._results_queue = [] self._current = 0 self._exec_count = 0 + self._exec_depth = 0 def execute(self, concurrency, fail_fast): self._fail_fast = fail_fast @@ -105,7 +65,7 @@ def execute(self, concurrency, fail_fast): self._current = 0 self._exec_count = 0 with self._condition: - for n in xrange(concurrency): + for n in range(concurrency): if not self._execute_next(): break return self._results() @@ -121,17 +81,22 @@ def _execute_next(self): pass def _execute(self, idx, statement, params): + self._exec_depth += 1 try: - future = self.session.execute_async(statement, params, timeout=None) + future = self.session.execute_async(statement, params, execution_profile=self._execution_profile) args = (future, idx) future.add_callbacks( callback=self._on_success, callback_args=args, errback=self._on_error, errback_args=args) except Exception as exc: - # exc_info with fail_fast to preserve stack trace info when raising on the client thread - # (matches previous behavior -- not sure why we wouldn't want stack trace in the other case) - e = sys.exc_info() if self._fail_fast and six.PY2 else exc - self._put_result(e, idx, False) + # If we're not failing fast and all executions are raising, there is a chance of recursing + # here as subsequent requests are attempted. If we hit this threshold, schedule this result/retry + # and let the event loop thread return. 
+ if self._exec_depth < self.max_error_recursion: + self._put_result(exc, idx, False) + else: + self.session.submit(self._put_result, exc, idx, False) + self._exec_depth -= 1 def _on_success(self, result, future, idx): future.clear_callbacks() @@ -140,20 +105,12 @@ def _on_success(self, result, future, idx): def _on_error(self, result, future, idx): self._put_result(result, idx, False) - @staticmethod - def _raise(exc): - if six.PY2 and isinstance(exc, tuple): - (exc_type, value, traceback) = exc - six.reraise(exc_type, value, traceback) - else: - raise exc - class ConcurrentExecutorGenResults(_ConcurrentExecutor): def _put_result(self, result, idx, success): with self._condition: - heappush(self._results_queue, (idx, (success, result))) + heappush(self._results_queue, (idx, ExecutionResult(success, result))) self._execute_next() self._condition.notify() @@ -167,7 +124,7 @@ def _results(self): try: self._condition.release() if self._fail_fast and not res[0]: - self._raise(res[1]) + raise res[1] yield res finally: self._condition.acquire() @@ -183,7 +140,7 @@ def execute(self, concurrency, fail_fast): return super(ConcurrentExecutorListResults, self).execute(concurrency, fail_fast) def _put_result(self, result, idx, success): - self._results_queue.append((idx, (success, result))) + self._results_queue.append((idx, ExecutionResult(success, result))) with self._condition: self._current += 1 if not success and self._fail_fast: @@ -198,23 +155,59 @@ def _results(self): while self._current < self._exec_count: self._condition.wait() if self._exception and self._fail_fast: - self._raise(self._exception) + raise self._exception if self._exception and self._fail_fast: # raise the exception even if there was no wait - self._raise(self._exception) + raise self._exception return [r[1] for r in sorted(self._results_queue)] def execute_concurrent_with_args(session, statement, parameters, *args, **kwargs): """ - Like :meth:`~cassandra.concurrent.execute_concurrent()`, but takes a single - statement and a sequence of parameters. Each item in ``parameters`` - should be a sequence or :const:`None`. + See :meth:`.Session.execute_concurrent_with_args`. + """ + return execute_concurrent(session, zip(cycle((statement,)), parameters), *args, **kwargs) - Example usage:: - statement = session.prepare("INSERT INTO mytable (a, b) VALUES (1, ?)") - parameters = [(x,) for x in range(1000)] - execute_concurrent_with_args(session, statement, parameters, concurrency=50) +class ConcurrentExecutorFutureResults(ConcurrentExecutorListResults): + def __init__(self, session, statements_and_params, execution_profile, future): + super().__init__(session, statements_and_params, execution_profile) + self.future = future + + def _put_result(self, result, idx, success): + super()._put_result(result, idx, success) + with self._condition: + if self._current == self._exec_count: + if self._exception and self._fail_fast: + self.future.set_exception(self._exception) + else: + sorted_results = [r[1] for r in sorted(self._results_queue)] + self.future.set_result(sorted_results) + + +def execute_concurrent_async( + session, + statements_and_parameters, + concurrency=100, + raise_on_first_error=False, + execution_profile=EXEC_PROFILE_DEFAULT +): """ - return execute_concurrent(session, zip(cycle((statement,)), parameters), *args, **kwargs) + See :meth:`.Session.execute_concurrent_async`. 
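A short sketch of the asynchronous variant added here, which reports completion through a standard ``concurrent.futures.Future``; the statements are placeholders::

    from cassandra.concurrent import execute_concurrent_async

    statements_and_params = [
        ("INSERT INTO events (id, payload) VALUES (%s, %s)", (i, 'x'))
        for i in range(100)
    ]

    future = execute_concurrent_async(session, statements_and_params, concurrency=25)

    # ... do other work while the requests are in flight ...

    for success, result_or_exc in future.result():  # list of ExecutionResult tuples
        if not success:
            print("statement failed:", result_or_exc)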
+ """ + # Create a Future object and initialize the custom ConcurrentExecutor with the Future + future = Future() + executor = ConcurrentExecutorFutureResults( + session=session, + statements_and_params=statements_and_parameters, + execution_profile=execution_profile, + future=future + ) + + # Execute concurrently + try: + executor.execute(concurrency=concurrency, fail_fast=raise_on_first_error) + except Exception as e: + future.set_exception(e) + + return future diff --git a/cassandra/connection.py b/cassandra/connection.py index 0624ad1364..4a16c46ab4 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -15,29 +17,25 @@ from __future__ import absolute_import # to enable import io from stdlib from collections import defaultdict, deque import errno -from functools import wraps, partial +from functools import wraps, partial, total_ordering from heapq import heappush, heappop import io import logging -import six -from six.moves import range import socket import struct import sys -from threading import Thread, Event, RLock +from threading import Thread, Event, RLock, Condition import time +import ssl +import weakref -try: - import ssl -except ImportError: - ssl = None # NOQA if 'gevent.monkey' in sys.modules: from gevent.queue import Queue, Empty else: - from six.moves.queue import Queue, Empty # noqa + from queue import Queue, Empty # noqa -from cassandra import ConsistencyLevel, AuthenticationFailed, OperationTimedOut +from cassandra import ConsistencyLevel, AuthenticationFailed, OperationTimedOut, ProtocolVersion from cassandra.marshal import int32_pack from cassandra.protocol import (ReadyMessage, AuthenticateMessage, OptionsMessage, StartupMessage, ErrorMessage, CredentialsMessage, @@ -45,12 +43,16 @@ InvalidRequestException, SupportedMessage, AuthResponseMessage, AuthChallengeMessage, AuthSuccessMessage, ProtocolException, - MAX_SUPPORTED_VERSION, RegisterMessage) + RegisterMessage, ReviseRequestMessage) +from cassandra.segment import SegmentCodec, CrcException from cassandra.util import OrderedDict log = logging.getLogger(__name__) +segment_codec_no_compression = SegmentCodec() +segment_codec_lz4 = None + # We use an ordered dictionary and specifically add lz4 before # snappy so that lz4 will be preferred. Changing the order of this # will change the compression preferences for the driver. @@ -61,6 +63,23 @@ except ImportError: pass else: + # The compress and decompress functions we need were moved from the lz4 to + # the lz4.block namespace, so we try both here. 
+ try: + from lz4 import block as lz4_block + except ImportError: + lz4_block = lz4 + + try: + lz4_block.compress + lz4_block.decompress + except AttributeError: + raise ImportError( + 'lz4 not imported correctly. Imported object should have ' + '.compress and and .decompress attributes but does not. ' + 'Please file a bug report on JIRA. (Imported object was ' + '{lz4_block})'.format(lz4_block=repr(lz4_block)) + ) # Cassandra writes the uncompressed message length in big endian order, # but the lz4 lib requires little endian order, so we wrap these @@ -68,13 +87,14 @@ def lz4_compress(byts): # write length in big-endian instead of little-endian - return int32_pack(len(byts)) + lz4.compress(byts)[4:] + return int32_pack(len(byts)) + lz4_block.compress(byts)[4:] def lz4_decompress(byts): # flip from big-endian to little-endian - return lz4.decompress(byts[3::-1] + byts[4:]) + return lz4_block.decompress(byts[3::-1] + byts[4:]) locally_supported_compressions['lz4'] = (lz4_compress, lz4_decompress) + segment_codec_lz4 = SegmentCodec(lz4_compress, lz4_decompress) try: import snappy @@ -88,6 +108,7 @@ def decompress(byts): return snappy.decompress(byts) locally_supported_compressions['snappy'] = (snappy.compress, decompress) +DRIVER_NAME, DRIVER_VERSION = 'DataStax Python Driver', sys.modules['cassandra'].__version__ PROTOCOL_VERSION_MASK = 0x7f @@ -99,6 +120,257 @@ def decompress(byts): frame_header_v3 = struct.Struct('>BhBi') +class EndPoint(object): + """ + Represents the information to connect to a cassandra node. + """ + + @property + def address(self): + """ + The IP address of the node. This is the RPC address the driver uses when connecting to the node + """ + raise NotImplementedError() + + @property + def port(self): + """ + The port of the node. + """ + raise NotImplementedError() + + @property + def ssl_options(self): + """ + SSL options specific to this endpoint. + """ + return None + + @property + def socket_family(self): + """ + The socket family of the endpoint. + """ + return socket.AF_UNSPEC + + def resolve(self): + """ + Resolve the endpoint to an address/port. This is called + only on socket connection. + """ + raise NotImplementedError() + + +class EndPointFactory(object): + + cluster = None + + def configure(self, cluster): + """ + This is called by the cluster during its initialization. + """ + self.cluster = cluster + return self + + def create(self, row): + """ + Create an EndPoint from a system.peers row. + """ + raise NotImplementedError() + + +@total_ordering +class DefaultEndPoint(EndPoint): + """ + Default EndPoint implementation, basically just an address and port. + """ + + def __init__(self, address, port=9042): + self._address = address + self._port = port + + @property + def address(self): + return self._address + + @property + def port(self): + return self._port + + def resolve(self): + return self._address, self._port + + def __eq__(self, other): + return isinstance(other, DefaultEndPoint) and \ + self.address == other.address and self.port == other.port + + def __hash__(self): + return hash((self.address, self.port)) + + def __lt__(self, other): + return (self.address, self.port) < (other.address, other.port) + + def __str__(self): + return str("%s:%d" % (self.address, self.port)) + + def __repr__(self): + return "<%s: %s:%d>" % (self.__class__.__name__, self.address, self.port) + + +class DefaultEndPointFactory(EndPointFactory): + + port = None + """ + If no port is discovered in the row, this is the default port + used for endpoint creation. 
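For illustration, endpoints can also be handed to the cluster directly when nodes listen on non-default ports; a sketch that assumes the documented support for EndPoint contact points, with placeholder addresses::

    from cassandra.cluster import Cluster
    from cassandra.connection import DefaultEndPoint

    cluster = Cluster(contact_points=[
        DefaultEndPoint('10.0.0.1'),           # default native port 9042
        DefaultEndPoint('10.0.0.2', 19042),    # non-default native port
    ])
    session = cluster.connect()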
+ """ + + def __init__(self, port=None): + self.port = port + + def create(self, row): + # TODO next major... move this class so we don't need this kind of hack + from cassandra.metadata import _NodeInfo + addr = _NodeInfo.get_broadcast_rpc_address(row) + port = _NodeInfo.get_broadcast_rpc_port(row) + if port is None: + port = self.port if self.port else 9042 + + # create the endpoint with the translated address + # TODO next major, create a TranslatedEndPoint type + return DefaultEndPoint( + self.cluster.address_translator.translate(addr), + port) + + +@total_ordering +class SniEndPoint(EndPoint): + """SNI Proxy EndPoint implementation.""" + + def __init__(self, proxy_address, server_name, port=9042, init_index=0): + self._proxy_address = proxy_address + self._index = init_index + self._resolved_address = None # resolved address + self._port = port + self._server_name = server_name + self._ssl_options = {'server_hostname': server_name} + + @property + def address(self): + return self._proxy_address + + @property + def port(self): + return self._port + + @property + def ssl_options(self): + return self._ssl_options + + def resolve(self): + try: + resolved_addresses = self._resolve_proxy_addresses() + except socket.gaierror: + log.debug('Could not resolve sni proxy hostname "%s" ' + 'with port %d' % (self._proxy_address, self._port)) + raise + + # round-robin pick + self._resolved_address = sorted(addr[4][0] for addr in resolved_addresses)[self._index % len(resolved_addresses)] + self._index += 1 + + return self._resolved_address, self._port + + def _resolve_proxy_addresses(self): + return socket.getaddrinfo(self._proxy_address, self._port, + socket.AF_UNSPEC, socket.SOCK_STREAM) + + def __eq__(self, other): + return (isinstance(other, SniEndPoint) and + self.address == other.address and self.port == other.port and + self._server_name == other._server_name) + + def __hash__(self): + return hash((self.address, self.port, self._server_name)) + + def __lt__(self, other): + return ((self.address, self.port, self._server_name) < + (other.address, other.port, self._server_name)) + + def __str__(self): + return str("%s:%d:%s" % (self.address, self.port, self._server_name)) + + def __repr__(self): + return "<%s: %s:%d:%s>" % (self.__class__.__name__, + self.address, self.port, self._server_name) + + +class SniEndPointFactory(EndPointFactory): + + def __init__(self, proxy_address, port): + self._proxy_address = proxy_address + self._port = port + # Initial lookup index to prevent all SNI endpoints to be resolved + # into the same starting IP address (which might not be available currently). + # If SNI resolves to 3 IPs, first endpoint will connect to first + # IP address, and subsequent resolutions to next IPs in round-robin + # fusion. + self._init_index = -1 + + def create(self, row): + host_id = row.get("host_id") + if host_id is None: + raise ValueError("No host_id to create the SniEndPoint") + + self._init_index += 1 + return SniEndPoint(self._proxy_address, str(host_id), self._port, self._init_index) + + def create_from_sni(self, sni): + self._init_index += 1 + return SniEndPoint(self._proxy_address, sni, self._port, self._init_index) + + +@total_ordering +class UnixSocketEndPoint(EndPoint): + """ + Unix Socket EndPoint implementation. 
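The SNI endpoint factory above is what the cloud connection path builds on; a hedged sketch of the usual entry point, with placeholder bundle path and credentials::

    from cassandra.cluster import Cluster
    from cassandra.auth import PlainTextAuthProvider

    cluster = Cluster(
        cloud={'secure_connect_bundle': '/path/to/secure-connect-db.zip'},
        auth_provider=PlainTextAuthProvider('client_id', 'client_secret'))
    session = cluster.connect()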
+ """ + + def __init__(self, unix_socket_path): + self._unix_socket_path = unix_socket_path + + @property + def address(self): + return self._unix_socket_path + + @property + def port(self): + return None + + @property + def socket_family(self): + return socket.AF_UNIX + + def resolve(self): + return self.address, None + + def __eq__(self, other): + return (isinstance(other, UnixSocketEndPoint) and + self._unix_socket_path == other._unix_socket_path) + + def __hash__(self): + return hash(self._unix_socket_path) + + def __lt__(self, other): + return self._unix_socket_path < other._unix_socket_path + + def __str__(self): + return str("%s" % (self._unix_socket_path,)) + + def __repr__(self): + return "<%s: %s>" % (self.__class__.__name__, self._unix_socket_path) + + class _Frame(object): def __init__(self, version, flags, stream, opcode, body_offset, end_pos): self.version = version @@ -122,18 +394,22 @@ def __str__(self): return "ver({0}); flags({1:04b}); stream({2}); op({3}); offset({4}); len({5})".format(self.version, self.flags, self.stream, self.opcode, self.body_offset, self.end_pos - self.body_offset) - NONBLOCKING = (errno.EAGAIN, errno.EWOULDBLOCK) + class ConnectionException(Exception): """ An unrecoverable error was hit when attempting to use a connection, or the connection was already closed or defunct. """ - def __init__(self, message, host=None): + def __init__(self, message, endpoint=None): Exception.__init__(self, message) - self.host = host + self.endpoint = endpoint + + @property + def host(self): + return self.endpoint.address class ConnectionShutdown(ConnectionException): @@ -147,9 +423,9 @@ class ProtocolVersionUnsupported(ConnectionException): """ Server rejected startup message due to unsupported protocol version """ - def __init__(self, host, startup_version): - super(ProtocolVersionUnsupported, self).__init__("Unsupported protocol version on %s: %d", - (host, startup_version)) + def __init__(self, endpoint, startup_version): + msg = "Unsupported protocol version on %s: %d" % (endpoint, startup_version) + super(ProtocolVersionUnsupported, self).__init__(msg, endpoint) self.startup_version = startup_version @@ -168,6 +444,165 @@ class ProtocolError(Exception): pass +class CrcMismatchException(ConnectionException): + pass + + +class ContinuousPagingState(object): + """ + A class for specifying continuous paging state, only supported starting with DSE_V2. 
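Continuous paging is driven from the execution-profile layer rather than on the connection directly; a hedged sketch, assuming ``ContinuousPagingOptions`` and the ``continuous_paging_options`` profile argument exposed by ``cassandra.cluster``, a DSE_V2-capable server, and placeholder contact point, query and ``process`` callable::

    from cassandra import ProtocolVersion
    from cassandra.cluster import (Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT,
                                   ContinuousPagingOptions)

    profile = ExecutionProfile(continuous_paging_options=ContinuousPagingOptions())
    cluster = Cluster(['127.0.0.1'],
                      protocol_version=ProtocolVersion.DSE_V2,
                      execution_profiles={EXEC_PROFILE_DEFAULT: profile})
    session = cluster.connect()

    # Pages are streamed by the server; iteration consumes them as they arrive.
    for row in session.execute("SELECT * FROM big_table"):
        process(row)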
+ """ + + num_pages_requested = None + """ + How many pages we have already requested + """ + + num_pages_received = None + """ + How many pages we have already received + """ + + max_queue_size = None + """ + The max queue size chosen by the user via the options + """ + + def __init__(self, max_queue_size): + self.num_pages_requested = max_queue_size # the initial query requests max_queue_size + self.num_pages_received = 0 + self.max_queue_size = max_queue_size + + +class ContinuousPagingSession(object): + def __init__(self, stream_id, decoder, row_factory, connection, state): + self.stream_id = stream_id + self.decoder = decoder + self.row_factory = row_factory + self.connection = connection + self._condition = Condition() + self._stop = False + self._page_queue = deque() + self._state = state + self.released = False + + def on_message(self, result): + if isinstance(result, ResultMessage): + self.on_page(result) + elif isinstance(result, ErrorMessage): + self.on_error(result) + + def on_page(self, result): + with self._condition: + if self._state: + self._state.num_pages_received += 1 + self._page_queue.appendleft((result.column_names, result.parsed_rows, None)) + self._stop |= result.continuous_paging_last + self._condition.notify() + + if result.continuous_paging_last: + self.released = True + + def on_error(self, error): + if isinstance(error, ErrorMessage): + error = error.to_exception() + + log.debug("Got error %s for session %s", error, self.stream_id) + + with self._condition: + self._page_queue.appendleft((None, None, error)) + self._stop = True + self._condition.notify() + + self.released = True + + def results(self): + try: + self._condition.acquire() + while True: + while not self._page_queue and not self._stop: + self._condition.wait(timeout=5) + while self._page_queue: + names, rows, err = self._page_queue.pop() + if err: + raise err + self.maybe_request_more() + self._condition.release() + for row in self.row_factory(names, rows): + yield row + self._condition.acquire() + if self._stop: + break + finally: + try: + self._condition.release() + except RuntimeError: + # This exception happens if the CP results are not entirely consumed + # and the session is terminated by the runtime. 
See PYTHON-1054 + pass + + def maybe_request_more(self): + if not self._state: + return + + max_queue_size = self._state.max_queue_size + num_in_flight = self._state.num_pages_requested - self._state.num_pages_received + space_in_queue = max_queue_size - len(self._page_queue) - num_in_flight + log.debug("Session %s from %s, space in CP queue: %s, requested: %s, received: %s, num_in_flight: %s", + self.stream_id, self.connection.host, space_in_queue, self._state.num_pages_requested, + self._state.num_pages_received, num_in_flight) + + if space_in_queue >= max_queue_size / 2: + self.update_next_pages(space_in_queue) + + def update_next_pages(self, num_next_pages): + try: + self._state.num_pages_requested += num_next_pages + log.debug("Updating backpressure for session %s from %s", self.stream_id, self.connection.host) + with self.connection.lock: + self.connection.send_msg(ReviseRequestMessage(ReviseRequestMessage.RevisionType.PAGING_BACKPRESSURE, + self.stream_id, + next_pages=num_next_pages), + self.connection.get_request_id(), + self._on_backpressure_response) + except ConnectionShutdown as ex: + log.debug("Failed to update backpressure for session %s from %s, connection is shutdown", + self.stream_id, self.connection.host) + self.on_error(ex) + + def _on_backpressure_response(self, response): + if isinstance(response, ResultMessage): + log.debug("Paging session %s backpressure updated.", self.stream_id) + else: + log.error("Failed updating backpressure for session %s from %s: %s", self.stream_id, self.connection.host, + response.to_exception() if hasattr(response, 'to_exception') else response) + self.on_error(response) + + def cancel(self): + try: + log.debug("Canceling paging session %s from %s", self.stream_id, self.connection.host) + with self.connection.lock: + self.connection.send_msg(ReviseRequestMessage(ReviseRequestMessage.RevisionType.PAGING_CANCEL, + self.stream_id), + self.connection.get_request_id(), + self._on_cancel_response) + except ConnectionShutdown: + log.debug("Failed to cancel session %s from %s, connection is shutdown", + self.stream_id, self.connection.host) + + with self._condition: + self._stop = True + self._condition.notify() + + def _on_cancel_response(self, response): + if isinstance(response, ResultMessage): + log.debug("Paging session %s canceled.", self.stream_id) + else: + log.error("Failed canceling streaming session %s from %s: %s", self.stream_id, self.connection.host, + response.to_exception() if hasattr(response, 'to_exception') else response) + self.released = True + + def defunct_on_error(f): @wraps(f) @@ -181,11 +616,59 @@ def wrapper(self, *args, **kwargs): DEFAULT_CQL_VERSION = '3.0.0' -if six.PY3: - def int_from_buf_item(i): - return i -else: - int_from_buf_item = ord + +class _ConnectionIOBuffer(object): + """ + Abstraction class to ease the use of the different connection io buffers. With + protocol V5 and checksumming, the data is read, validated and copied to another + cql frame buffer. 
+ """ + _io_buffer = None + _cql_frame_buffer = None + _connection = None + _segment_consumed = False + + def __init__(self, connection): + self._io_buffer = io.BytesIO() + self._connection = weakref.proxy(connection) + + @property + def io_buffer(self): + return self._io_buffer + + @property + def cql_frame_buffer(self): + return self._cql_frame_buffer if self.is_checksumming_enabled else \ + self._io_buffer + + def set_checksumming_buffer(self): + self.reset_io_buffer() + self._cql_frame_buffer = io.BytesIO() + + @property + def is_checksumming_enabled(self): + return self._connection._is_checksumming_enabled + + @property + def has_consumed_segment(self): + return self._segment_consumed; + + def readable_io_bytes(self): + return self.io_buffer.tell() + + def readable_cql_frame_bytes(self): + return self.cql_frame_buffer.tell() + + def reset_io_buffer(self): + self._io_buffer = io.BytesIO(self._io_buffer.read()) + self._io_buffer.seek(0, 2) # 2 == SEEK_END + + def reset_cql_frame_buffer(self): + if self.is_checksumming_enabled: + self._cql_frame_buffer = io.BytesIO(self._cql_frame_buffer.read()) + self._cql_frame_buffer.seek(0, 2) # 2 == SEEK_END + else: + self.reset_io_buffer() class Connection(object): @@ -196,20 +679,31 @@ class Connection(object): out_buffer_size = 4096 cql_version = None - protocol_version = MAX_SUPPORTED_VERSION + no_compact = False + protocol_version = ProtocolVersion.MAX_SUPPORTED keyspace = None compression = True + _compression_type = None compressor = None decompressor = None + endpoint = None ssl_options = None + ssl_context = None last_error = None # The current number of operations that are in flight. More precisely, # the number of request IDs that are currently in use. + # This includes orphaned requests. in_flight = 0 + # Max concurrent requests allowed per connection. This is set optimistically high, allowing + # all request ids to be used in protocol version 3+. Normally concurrency would be controlled + # at a higher level by the application or concurrent.execute_concurrent. This attribute + # is for lower-level integrations that want some upper bound without reimplementing. + max_in_flight = 2 ** 15 + # A set of available request IDs. When using the v3 protocol or higher, # this will not initially include all request IDs in order to save memory, # but the set will grow if it is exhausted. 
@@ -219,6 +713,20 @@ class Connection(object): # request_ids set highest_request_id = 0 + # Tracks the request IDs which are no longer waited on (timed out), but + # cannot be reused yet because the node might still send a response + # on this stream + orphaned_request_ids = None + + # Set to true if the orphaned stream ID count cross configured threshold + # and the connection will be replaced + orphaned_threshold_reached = False + + # If the number of orphaned streams reaches this threshold, this connection + # will become marked and will be replaced with a new connection by the + # owning pool (currently, only HostConnection supports this) + orphaned_threshold = 3 * max_in_flight // 4 + is_defunct = False is_closed = False lock = None @@ -231,24 +739,38 @@ class Connection(object): is_control_connection = False signaled_error = False # used for flagging at the pool level - _server_version = None + allow_beta_protocol_version = False - _iobuf = None _current_frame = None _socket = None _socket_impl = socket - _ssl_impl = ssl + + _check_hostname = False + _product_type = None + + _is_checksumming_enabled = False + + _on_orphaned_stream_released = None + + @property + def _iobuf(self): + # backward compatibility, to avoid any change in the reactors + return self._io_buffer.io_buffer def __init__(self, host='127.0.0.1', port=9042, authenticator=None, ssl_options=None, sockopts=None, compression=True, - cql_version=None, protocol_version=MAX_SUPPORTED_VERSION, is_control_connection=False, - user_type_map=None, connect_timeout=None): - self.host = host - self.port = port + cql_version=None, protocol_version=ProtocolVersion.MAX_SUPPORTED, is_control_connection=False, + user_type_map=None, connect_timeout=None, allow_beta_protocol_version=False, no_compact=False, + ssl_context=None, on_orphaned_stream_released=None): + + # TODO next major rename host to endpoint and remove port kwarg. + self.endpoint = host if isinstance(host, EndPoint) else DefaultEndPoint(host, port) + self.authenticator = authenticator - self.ssl_options = ssl_options + self.ssl_options = ssl_options.copy() if ssl_options else {} + self.ssl_context = ssl_context self.sockopts = sockopts self.compression = compression self.cql_version = cql_version @@ -256,24 +778,55 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None, self.is_control_connection = is_control_connection self.user_type_map = user_type_map self.connect_timeout = connect_timeout + self.allow_beta_protocol_version = allow_beta_protocol_version + self.no_compact = no_compact self._push_watchers = defaultdict(set) self._requests = {} - self._iobuf = io.BytesIO() + self._io_buffer = _ConnectionIOBuffer(self) + self._continuous_paging_sessions = {} + self._socket_writable = True + self.orphaned_request_ids = set() + self._on_orphaned_stream_released = on_orphaned_stream_released + + if ssl_options: + self.ssl_options.update(self.endpoint.ssl_options or {}) + elif self.endpoint.ssl_options: + self.ssl_options = self.endpoint.ssl_options + + # PYTHON-1331 + # + # We always use SSLContext.wrap_socket() now but legacy configs may have other params that were passed to ssl.wrap_socket()... + # and either could have 'check_hostname'. Remove these params into a separate map and use them to build an SSLContext if + # we need to do so. + # + # Note the use of pop() here; we are very deliberately removing these params from ssl_options if they're present. After this + # operation ssl_options should contain only args needed for the ssl_context.wrap_socket() call. 
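For context on the PYTHON-1331 handling above, a hedged usage sketch: callers can either pass a ready-made ``ssl_context``, or keep passing legacy ``ssl_options`` and let the connection split them into SSLContext-construction and ``wrap_socket()`` arguments. Paths and hostnames below are placeholders::

    import ssl

    # Preferred: build the SSLContext explicitly and pass it as ssl_context.
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    ctx.verify_mode = ssl.CERT_REQUIRED
    ctx.check_hostname = True
    ctx.load_verify_locations('/path/to/ca.pem')       # placeholder path

    # Legacy: raw ssl.wrap_socket()-era options; when no ssl_context is given,
    # the SSLContext-related keys are popped out and used to build one.
    legacy_ssl_options = {
        'ca_certs': '/path/to/ca.pem',                  # placeholder path
        'cert_reqs': ssl.CERT_REQUIRED,
        'check_hostname': True,
        'server_hostname': 'db.example.com',            # placeholder hostname
    }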
+ if not self.ssl_context and self.ssl_options: + self.ssl_context = self._build_ssl_context_from_options() if protocol_version >= 3: - self.max_request_id = (2 ** 15) - 1 - # Don't fill the deque with 2**15 items right away. Start with 300 and add + self.max_request_id = min(self.max_in_flight - 1, (2 ** 15) - 1) + # Don't fill the deque with 2**15 items right away. Start with some and add # more if needed. - self.request_ids = deque(range(300)) - self.highest_request_id = 299 + initial_size = min(300, self.max_in_flight) + self.request_ids = deque(range(initial_size)) + self.highest_request_id = initial_size - 1 else: - self.max_request_id = (2 ** 7) - 1 + self.max_request_id = min(self.max_in_flight, (2 ** 7) - 1) self.request_ids = deque(range(self.max_request_id + 1)) self.highest_request_id = self.max_request_id self.lock = RLock() self.connected_event = Event() + @property + def host(self): + return self.endpoint.address + + @property + def port(self): + return self.endpoint.port + @classmethod def initialize_reactor(cls): """ @@ -285,7 +838,7 @@ def initialize_reactor(cls): @classmethod def handle_fork(cls): """ - Called after a forking. This should cleanup any remaining reactor state + Called after a forking. This should clean up any remaining reactor state from the parent process. """ pass @@ -295,7 +848,7 @@ def create_timer(cls, timeout, callback): raise NotImplementedError() @classmethod - def factory(cls, host, timeout, *args, **kwargs): + def factory(cls, endpoint, timeout, *args, **kwargs): """ A factory function which returns connections which have succeeded in connecting and are ready for service (or @@ -303,12 +856,12 @@ def factory(cls, host, timeout, *args, **kwargs): """ start = time.time() kwargs['connect_timeout'] = timeout - conn = cls(host, *args, **kwargs) + conn = cls(endpoint, *args, **kwargs) elapsed = time.time() - start conn.connected_event.wait(timeout - elapsed) if conn.last_error: if conn.is_unsupported_proto_version: - raise ProtocolVersionUnsupported(host, conn.protocol_version) + raise ProtocolVersionUnsupported(endpoint, conn.protocol_version) raise conn.last_error elif not conn.connected_event.is_set(): conn.close() @@ -316,18 +869,89 @@ def factory(cls, host, timeout, *args, **kwargs): else: return conn + def _build_ssl_context_from_options(self): + + # Extract a subset of names from self.ssl_options which apply to SSLContext creation + ssl_context_opt_names = ['ssl_version', 'cert_reqs', 'check_hostname', 'keyfile', 'certfile', 'ca_certs', 'ciphers'] + opts = {k:self.ssl_options.get(k, None) for k in ssl_context_opt_names if k in self.ssl_options} + + # Python >= 3.10 requires either PROTOCOL_TLS_CLIENT or PROTOCOL_TLS_SERVER, so we'll get ahead of things by always + # being explicit + ssl_version = opts.get('ssl_version', None) or ssl.PROTOCOL_TLS_CLIENT + cert_reqs = opts.get('cert_reqs', None) or ssl.CERT_REQUIRED + rv = ssl.SSLContext(protocol=int(ssl_version)) + rv.check_hostname = bool(opts.get('check_hostname', False)) + rv.options = int(cert_reqs) + + certfile = opts.get('certfile', None) + keyfile = opts.get('keyfile', None) + if certfile: + rv.load_cert_chain(certfile, keyfile) + ca_certs = opts.get('ca_certs', None) + if ca_certs: + rv.load_verify_locations(ca_certs) + ciphers = opts.get('ciphers', None) + if ciphers: + rv.set_ciphers(ciphers) + + return rv + + def _wrap_socket_from_context(self): + + # Extract a subset of names from self.ssl_options which apply to SSLContext.wrap_socket (or at least the parts + # of it that don't 
involve building an SSLContext under the covers) + wrap_socket_opt_names = ['server_side', 'do_handshake_on_connect', 'suppress_ragged_eofs', 'server_hostname'] + opts = {k:self.ssl_options.get(k, None) for k in wrap_socket_opt_names if k in self.ssl_options} + + # PYTHON-1186: set the server_hostname only if the SSLContext has + # check_hostname enabled, and it is not already provided by the EndPoint ssl options + #opts['server_hostname'] = self.endpoint.address + if (self.ssl_context.check_hostname and 'server_hostname' not in opts): + server_hostname = self.endpoint.address + opts['server_hostname'] = server_hostname + + return self.ssl_context.wrap_socket(self._socket, **opts) + + def _initiate_connection(self, sockaddr): + self._socket.connect(sockaddr) + + # PYTHON-1331 + # + # Allow implementations specific to an event loop to add additional behaviours + def _validate_hostname(self): + pass + + def _get_socket_addresses(self): + address, port = self.endpoint.resolve() + + if hasattr(socket, 'AF_UNIX') and self.endpoint.socket_family == socket.AF_UNIX: + return [(socket.AF_UNIX, socket.SOCK_STREAM, 0, None, address)] + + addresses = socket.getaddrinfo(address, port, self.endpoint.socket_family, socket.SOCK_STREAM) + if not addresses: + raise ConnectionException("getaddrinfo returned empty list for %s" % (self.endpoint,)) + + return addresses + def _connect_socket(self): sockerr = None - addresses = socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM) - for (af, socktype, proto, canonname, sockaddr) in addresses: + addresses = self._get_socket_addresses() + for (af, socktype, proto, _, sockaddr) in addresses: try: self._socket = self._socket_impl.socket(af, socktype, proto) - if self.ssl_options: - if not self._ssl_impl: - raise Exception("This version of Python was not compiled with SSL support") - self._socket = self._ssl_impl.wrap_socket(self._socket, **self.ssl_options) + if self.ssl_context: + self._socket = self._wrap_socket_from_context() self._socket.settimeout(self.connect_timeout) - self._socket.connect(sockaddr) + self._initiate_connection(sockaddr) + self._socket.settimeout(None) + + # PYTHON-1331 + # + # Most checking is done via the check_hostname param on the SSLContext. + # Subclasses can add additional behaviours via _validate_hostname() so + # run that here. + if self._check_hostname: + self._validate_hostname() sockerr = None break except socket.error as err: @@ -337,12 +961,23 @@ def _connect_socket(self): sockerr = err if sockerr: - raise socket.error(sockerr.errno, "Tried connecting to %s. Last error: %s" % ([a[4] for a in addresses], sockerr.strerror or sockerr)) + raise socket.error(sockerr.errno, "Tried connecting to %s. 
Last error: %s" % + ([a[4] for a in addresses], sockerr.strerror or sockerr)) if self.sockopts: for args in self.sockopts: self._socket.setsockopt(*args) + def _enable_compression(self): + if self._compressor: + self.compressor = self._compressor + + def _enable_checksumming(self): + self._io_buffer.set_checksumming_buffer() + self._is_checksumming_enabled = True + self._segment_codec = segment_codec_lz4 if self.compressor else segment_codec_no_compression + log.debug("Enabling protocol checksumming on connection (%s).", id(self)) + def close(self): raise NotImplementedError() @@ -356,17 +991,23 @@ def defunct(self, exc): # if we are not handling an exception, just use the passed exception, and don't try to format exc_info with the message if any(exc_info): log.debug("Defuncting connection (%s) to %s:", - id(self), self.host, exc_info=exc_info) + id(self), self.endpoint, exc_info=exc_info) else: log.debug("Defuncting connection (%s) to %s: %s", - id(self), self.host, exc) + id(self), self.endpoint, exc) self.last_error = exc self.close() + self.error_all_cp_sessions(exc) self.error_all_requests(exc) self.connected_event.set() return exc + def error_all_cp_sessions(self, exc): + stream_ids = list(self._continuous_paging_sessions.keys()) + for stream_id in stream_ids: + self._continuous_paging_sessions[stream_id].on_error(exc) + def error_all_requests(self, exc): with self.lock: requests = self._requests @@ -376,16 +1017,17 @@ def error_all_requests(self, exc): return new_exc = ConnectionShutdown(str(exc)) + def try_callback(cb): try: cb(new_exc) except Exception: log.warning("Ignoring unhandled exception while erroring requests for a " "failed connection (%s) to host %s:", - id(self), self.host, exc_info=True) + id(self), self.endpoint, exc_info=True) # run first callback from this thread to ensure pool state before leaving - cb, _ = requests.popitem()[1] + cb, _, _ = requests.popitem()[1] try_callback(cb) if not requests: @@ -395,7 +1037,7 @@ def try_callback(cb): # The default callback and retry logic is fairly expensive -- we don't # want to tie up the event thread when there are many requests def err_all_callbacks(): - for cb, _ in requests.values(): + for cb, _, _ in requests.values(): try_callback(cb) if len(requests) < Connection.CALLBACK_ERR_THREAD_THRESHOLD: err_all_callbacks() @@ -413,9 +1055,10 @@ def get_request_id(self): try: return self.request_ids.popleft() except IndexError: - self.highest_request_id += 1 + new_request_id = self.highest_request_id + 1 # in_flight checks should guarantee this - assert self.highest_request_id <= self.max_request_id + assert new_request_id <= self.max_request_id + self.highest_request_id = new_request_id return self.highest_request_id def handle_pushed(self, response): @@ -426,20 +1069,30 @@ def handle_pushed(self, response): except Exception: log.exception("Pushed event handler errored, ignoring:") - def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message): + def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=None): if self.is_defunct: - raise ConnectionShutdown("Connection to %s is defunct" % self.host) + raise ConnectionShutdown("Connection to %s is defunct" % self.endpoint) elif self.is_closed: - raise ConnectionShutdown("Connection to %s is closed" % self.host) + raise ConnectionShutdown("Connection to %s is closed" % self.endpoint) + elif not self._socket_writable: + raise 
ConnectionBusy("Connection %s is overloaded" % self.endpoint) # queue the decoder function with the request # this allows us to inject custom functions per request to encode, decode messages - self._requests[request_id] = (cb, decoder) - self.push(encoder(msg, request_id, self.protocol_version, compressor=self.compressor)) - return request_id + self._requests[request_id] = (cb, decoder, result_metadata) + msg = encoder(msg, request_id, self.protocol_version, compressor=self.compressor, + allow_beta_protocol_version=self.allow_beta_protocol_version) + + if self._is_checksumming_enabled: + buffer = io.BytesIO() + self._segment_codec.encode(buffer, msg) + msg = buffer.getvalue() + + self.push(msg) + return len(msg) - def wait_for_response(self, msg, timeout=None): - return self.wait_for_responses(msg, timeout=timeout)[0] + def wait_for_response(self, msg, timeout=None, **kwargs): + return self.wait_for_responses(msg, timeout=timeout, **kwargs)[0] def wait_for_responses(self, *msgs, **kwargs): """ @@ -461,7 +1114,7 @@ def wait_for_responses(self, *msgs, **kwargs): while True: needed = len(msgs) - messages_sent with self.lock: - available = min(needed, self.max_request_id - self.in_flight) + available = min(needed, self.max_request_id - self.in_flight + 1) request_ids = [self.get_request_id() for _ in range(available)] self.in_flight += available @@ -513,11 +1166,11 @@ def control_conn_disposed(self): @defunct_on_error def _read_frame_header(self): - buf = self._iobuf.getvalue() + buf = self._io_buffer.cql_frame_buffer.getvalue() pos = len(buf) if pos: - version = int_from_buf_item(buf[0]) & PROTOCOL_VERSION_MASK - if version > MAX_SUPPORTED_VERSION: + version = buf[0] & PROTOCOL_VERSION_MASK + if version not in ProtocolVersion.SUPPORTED_VERSIONS: raise ProtocolError("This version of the driver does not support protocol version %d" % version) frame_header = frame_header_v3 if version >= 3 else frame_header_v1_v2 # this frame header struct is everything after the version byte @@ -529,46 +1182,99 @@ def _read_frame_header(self): self._current_frame = _Frame(version, flags, stream, op, header_size, body_len + header_size) return pos - def _reset_frame(self): - self._iobuf = io.BytesIO(self._iobuf.read()) - self._iobuf.seek(0, 2) # io.SEEK_END == 2 (constant not present in 2.6) - self._current_frame = None + @defunct_on_error + def _process_segment_buffer(self): + readable_bytes = self._io_buffer.readable_io_bytes() + if readable_bytes >= self._segment_codec.header_length_with_crc: + try: + self._io_buffer.io_buffer.seek(0) + segment_header = self._segment_codec.decode_header(self._io_buffer.io_buffer) + + if readable_bytes >= segment_header.segment_length: + segment = self._segment_codec.decode(self._iobuf, segment_header) + self._io_buffer._segment_consumed = True + self._io_buffer.cql_frame_buffer.write(segment.payload) + else: + # not enough data to read the segment. reset the buffer pointer at the + # beginning to not lose what we previously read (header). 
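A simplified sketch of the protocol-v5 checksumming flow implemented above, where ``segment_codec`` stands in for the driver's ``segment_codec_no_compression``/``segment_codec_lz4`` objects: outgoing frames are wrapped by ``encode()`` before being pushed, and a segment's payload is only copied into the CQL frame buffer once the full segment has arrived (CRC mismatches surface as exceptions in the real code)::

    import io

    def wrap_outgoing_frame(segment_codec, frame_bytes):
        # Mirrors send_msg(): wrap the encoded CQL frame in a checksummed
        # segment and push those bytes instead of the raw frame.
        buf = io.BytesIO()
        segment_codec.encode(buf, frame_bytes)
        return buf.getvalue()

    def try_consume_segment(segment_codec, io_buffer, cql_frame_buffer):
        # Mirrors _process_segment_buffer(): return True once a whole segment
        # has been copied into the CQL frame buffer, False if more bytes are needed.
        readable = io_buffer.tell()
        if readable < segment_codec.header_length_with_crc:
            return False
        io_buffer.seek(0)
        header = segment_codec.decode_header(io_buffer)
        if readable < header.segment_length:
            io_buffer.seek(0)          # keep the header bytes for the next pass
            return False
        segment = segment_codec.decode(io_buffer, header)
        cql_frame_buffer.write(segment.payload)
        return True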
+ self._io_buffer._segment_consumed = False + self._io_buffer.io_buffer.seek(0) + except CrcException as exc: + # re-raise an exception that inherits from ConnectionException + raise CrcMismatchException(str(exc), self.endpoint) + else: + self._io_buffer._segment_consumed = False def process_io_buffer(self): while True: + if self._is_checksumming_enabled and self._io_buffer.readable_io_bytes(): + self._process_segment_buffer() + self._io_buffer.reset_io_buffer() + + if self._is_checksumming_enabled and not self._io_buffer.has_consumed_segment: + # We couldn't read an entire segment from the io buffer, so return + # control to allow more bytes to be read off the wire + return + if not self._current_frame: pos = self._read_frame_header() else: - pos = self._iobuf.tell() + pos = self._io_buffer.readable_cql_frame_bytes() if not self._current_frame or pos < self._current_frame.end_pos: - # we don't have a complete header yet or we + if self._is_checksumming_enabled and self._io_buffer.readable_io_bytes(): + # We have a multi-segments message, and we need to read more + # data to complete the current cql frame + continue + + # we don't have a complete header yet, or we # already saw a header, but we don't have a # complete message yet return else: frame = self._current_frame - self._iobuf.seek(frame.body_offset) - msg = self._iobuf.read(frame.end_pos - frame.body_offset) + self._io_buffer.cql_frame_buffer.seek(frame.body_offset) + msg = self._io_buffer.cql_frame_buffer.read(frame.end_pos - frame.body_offset) self.process_msg(frame, msg) - self._reset_frame() + self._io_buffer.reset_cql_frame_buffer() + self._current_frame = None @defunct_on_error def process_msg(self, header, body): + self.msg_received = True stream_id = header.stream if stream_id < 0: callback = None decoder = ProtocolHandler.decode_message + result_metadata = None else: - callback, decoder = self._requests.pop(stream_id, None) - with self.lock: - self.request_ids.append(stream_id) + if stream_id in self._continuous_paging_sessions: + paging_session = self._continuous_paging_sessions[stream_id] + callback = paging_session.on_message + decoder = paging_session.decoder + result_metadata = None + else: + need_notify_of_release = False + with self.lock: + if stream_id in self.orphaned_request_ids: + self.in_flight -= 1 + self.orphaned_request_ids.remove(stream_id) + need_notify_of_release = True + if need_notify_of_release and self._on_orphaned_stream_released: + self._on_orphaned_stream_released() - self.msg_received = True + try: + callback, decoder, result_metadata = self._requests.pop(stream_id) + # This can only happen if the stream_id was + # removed due to an OperationTimedOut + except KeyError: + with self.lock: + self.request_ids.append(stream_id) + return try: response = decoder(header.version, self.user_type_map, stream_id, - header.flags, header.opcode, body, self.decompressor) + header.flags, header.opcode, body, self.decompressor, result_metadata) except Exception as exc: log.exception("Error decoding response from Cassandra. 
" "%s; buffer: %r", header, self._iobuf.getvalue()) @@ -582,8 +1288,8 @@ def process_msg(self, header, body): if isinstance(response, ProtocolException): if 'unsupported protocol version' in response.message: self.is_unsupported_proto_version = True - - log.error("Closing connection %s due to protocol error: %s", self, response.summary_msg()) + else: + log.error("Closing connection %s due to protocol error: %s", self, response.summary_msg()) self.defunct(response) if callback is not None: callback(response) @@ -592,18 +1298,33 @@ def process_msg(self, header, body): except Exception: log.exception("Callback handler errored, ignoring:") + # done after callback because the callback might signal this as a paging session + if stream_id >= 0: + if stream_id in self._continuous_paging_sessions: + if self._continuous_paging_sessions[stream_id].released: + self.remove_continuous_paging_session(stream_id) + else: + with self.lock: + self.request_ids.append(stream_id) + + def new_continuous_paging_session(self, stream_id, decoder, row_factory, state): + session = ContinuousPagingSession(stream_id, decoder, row_factory, self, state) + self._continuous_paging_sessions[stream_id] = session + return session + + def remove_continuous_paging_session(self, stream_id): + try: + self._continuous_paging_sessions.pop(stream_id) + with self.lock: + log.debug("Returning cp session stream id %s", stream_id) + self.request_ids.append(stream_id) + except KeyError: + pass + @defunct_on_error def _send_options_message(self): - if self.cql_version is None and (not self.compression or not locally_supported_compressions): - log.debug("Not sending options message for new connection(%s) to %s " - "because compression is disabled and a cql version was not " - "specified", id(self), self.host) - self._compressor = None - self.cql_version = DEFAULT_CQL_VERSION - self._send_startup_message() - else: - log.debug("Sending initial options message for new connection (%s) to %s", id(self), self.host) - self.send_msg(OptionsMessage(), self.get_request_id(), self._handle_options_response) + log.debug("Sending initial options message for new connection (%s) to %s", id(self), self.endpoint) + self.send_msg(OptionsMessage(), self.get_request_id(), self._handle_options_response) @defunct_on_error def _handle_options_response(self, options_response): @@ -621,9 +1342,10 @@ def _handle_options_response(self, options_response): % (options_response,)) log.debug("Received options response on new connection (%s) from %s", - id(self), self.host) + id(self), self.endpoint) supported_cql_versions = options_response.cql_versions remote_supported_compressions = options_response.options['COMPRESSION'] + self._product_type = options_response.options.get('PRODUCT_TYPE', [None])[0] if self.cql_version: if self.cql_version not in supported_cql_versions: @@ -646,12 +1368,12 @@ def _handle_options_response(self, options_response): remote_supported_compressions) else: compression_type = None - if isinstance(self.compression, six.string_types): + if isinstance(self.compression, str): # the user picked a specific compression type ('snappy' or 'lz4') if self.compression not in remote_supported_compressions: raise ProtocolError( "The requested compression type (%s) is not supported by the Cassandra server at %s" - % (self.compression, self.host)) + % (self.compression, self.endpoint)) compression_type = self.compression else: # our locally supported compressions are ordered to prefer @@ -661,19 +1383,31 @@ def _handle_options_response(self, options_response): 
compression_type = k break - # set the decompressor here, but set the compressor only after - # a successful Ready message - self._compressor, self.decompressor = \ - locally_supported_compressions[compression_type] + # If snappy compression is selected with v5+checksumming, the connection + # will fail with OTO. Only lz4 is supported + if (compression_type == 'snappy' and + ProtocolVersion.has_checksumming_support(self.protocol_version)): + log.debug("Snappy compression is not supported with protocol version %s and " + "checksumming. Consider installing lz4. Disabling compression.", self.protocol_version) + compression_type = None + else: + # set the decompressor here, but set the compressor only after + # a successful Ready message + self._compression_type = compression_type + self._compressor, self.decompressor = \ + locally_supported_compressions[compression_type] - self._send_startup_message(compression_type) + self._send_startup_message(compression_type, no_compact=self.no_compact) @defunct_on_error - def _send_startup_message(self, compression=None): + def _send_startup_message(self, compression=None, no_compact=False): log.debug("Sending StartupMessage on %s", self) - opts = {} + opts = {'DRIVER_NAME': DRIVER_NAME, + 'DRIVER_VERSION': DRIVER_VERSION} if compression: opts['COMPRESSION'] = compression + if no_compact: + opts['NO_COMPACT'] = 'true' sm = StartupMessage(cqlversion=self.cql_version, options=opts) self.send_msg(sm, self.get_request_id(), cb=self._handle_startup_response) log.debug("Sent StartupMessage on %s", self) @@ -682,17 +1416,34 @@ def _send_startup_message(self, compression=None): def _handle_startup_response(self, startup_response, did_authenticate=False): if self.is_defunct: return + if isinstance(startup_response, ReadyMessage): - log.debug("Got ReadyMessage on new connection (%s) from %s", id(self), self.host) - if self._compressor: - self.compressor = self._compressor + if self.authenticator: + log.warning("An authentication challenge was not sent, " + "this is suspicious because the driver expects " + "authentication (configured authenticator = %s)", + self.authenticator.__class__.__name__) + + log.debug("Got ReadyMessage on new connection (%s) from %s", id(self), self.endpoint) + self._enable_compression() + + if ProtocolVersion.has_checksumming_support(self.protocol_version): + self._enable_checksumming() + self.connected_event.set() elif isinstance(startup_response, AuthenticateMessage): log.debug("Got AuthenticateMessage on new connection (%s) from %s: %s", - id(self), self.host, startup_response.authenticator) + id(self), self.endpoint, startup_response.authenticator) if self.authenticator is None: - raise AuthenticationFailed('Remote end requires authentication.') + log.error("Failed to authenticate to %s. 
If you are trying to connect to a DSE cluster, " + "consider using TransitionalModePlainTextAuthProvider " + "if DSE authentication is configured with transitional mode" % (self.host,)) + raise AuthenticationFailed('Remote end requires authentication') + + self._enable_compression() + if ProtocolVersion.has_checksumming_support(self.protocol_version): + self._enable_checksumming() if isinstance(self.authenticator, dict): log.debug("Sending credentials-based auth response on %s", self) @@ -704,20 +1455,21 @@ def _handle_startup_response(self, startup_response, did_authenticate=False): self.authenticator.server_authenticator_class = startup_response.authenticator initial_response = self.authenticator.initial_response() initial_response = "" if initial_response is None else initial_response - self.send_msg(AuthResponseMessage(initial_response), self.get_request_id(), self._handle_auth_response) + self.send_msg(AuthResponseMessage(initial_response), self.get_request_id(), + self._handle_auth_response) elif isinstance(startup_response, ErrorMessage): log.debug("Received ErrorMessage on new connection (%s) from %s: %s", - id(self), self.host, startup_response.summary_msg()) + id(self), self.endpoint, startup_response.summary_msg()) if did_authenticate: raise AuthenticationFailed( "Failed to authenticate to %s: %s" % - (self.host, startup_response.summary_msg())) + (self.endpoint, startup_response.summary_msg())) else: raise ConnectionException( "Failed to initialize new connection to %s: %s" - % (self.host, startup_response.summary_msg())) + % (self.endpoint, startup_response.summary_msg())) elif isinstance(startup_response, ConnectionShutdown): - log.debug("Connection to %s was closed during the startup handshake", (self.host)) + log.debug("Connection to %s was closed during the startup handshake", (self.endpoint)) raise startup_response else: msg = "Unexpected response during Connection setup: %r" @@ -742,17 +1494,17 @@ def _handle_auth_response(self, auth_response): self.send_msg(msg, self.get_request_id(), self._handle_auth_response) elif isinstance(auth_response, ErrorMessage): log.debug("Received ErrorMessage on new connection (%s) from %s: %s", - id(self), self.host, auth_response.summary_msg()) + id(self), self.endpoint, auth_response.summary_msg()) raise AuthenticationFailed( "Failed to authenticate to %s: %s" % - (self.host, auth_response.summary_msg())) + (self.endpoint, auth_response.summary_msg())) elif isinstance(auth_response, ConnectionShutdown): - log.debug("Connection to %s was closed during the authentication process", self.host) + log.debug("Connection to %s was closed during the authentication process", self.endpoint) raise auth_response else: msg = "Unexpected response during Connection authentication to %s: %r" - log.error(msg, self.host, auth_response) - raise ProtocolError(msg % (self.host, auth_response)) + log.error(msg, self.endpoint, auth_response) + raise ProtocolError(msg % (self.endpoint, auth_response)) def set_keyspace_blocking(self, keyspace): if not keyspace or keyspace == self.keyspace: @@ -767,7 +1519,7 @@ def set_keyspace_blocking(self, keyspace): raise ire.to_exception() except Exception as exc: conn_exc = ConnectionException( - "Problem while setting keyspace: %r" % (exc,), self.host) + "Problem while setting keyspace: %r" % (exc,), self.endpoint) self.defunct(conn_exc) raise conn_exc @@ -775,7 +1527,7 @@ def set_keyspace_blocking(self, keyspace): self.keyspace = keyspace else: conn_exc = ConnectionException( - "Problem while setting keyspace: %r" % 
(result,), self.host) + "Problem while setting keyspace: %r" % (result,), self.endpoint) self.defunct(conn_exc) raise conn_exc @@ -785,7 +1537,29 @@ def set_keyspace_async(self, keyspace, callback): When the operation completes, `callback` will be called with two arguments: this connection and an Exception if an error occurred, otherwise :const:`None`. + + This method will always increment :attr:`.in_flight` attribute, even if + it doesn't need to make a request, just to maintain an + ":attr:`.in_flight` is incremented" invariant. """ + # Here we increment in_flight unconditionally, whether we need to issue + # a request or not. This is bad, but allows callers -- specifically + # _set_keyspace_for_all_conns -- to assume that we increment + # self.in_flight during this call. This allows the passed callback to + # safely call HostConnection{Pool,}.return_connection on this + # Connection. + # + # We use a busy wait on the lock here because: + # - we'll only spin if the connection is at max capacity, which is very + # unlikely for a set_keyspace call + # - it allows us to avoid signaling a condition every time a request completes + while True: + with self.lock: + if self.in_flight < self.max_request_id: + self.in_flight += 1 + break + time.sleep(0.001) + if not keyspace or keyspace == self.keyspace: callback(self, None) return @@ -801,21 +1575,11 @@ def process_result(result): callback(self, result.to_exception()) else: callback(self, self.defunct(ConnectionException( - "Problem while setting keyspace: %r" % (result,), self.host))) + "Problem while setting keyspace: %r" % (result,), self.endpoint))) - request_id = None - # we use a busy wait on the lock here because: - # - we'll only spin if the connection is at max capacity, which is very - # unlikely for a set_keyspace call - # - it allows us to avoid signaling a condition every time a request completes - while True: - with self.lock: - if self.in_flight < self.max_request_id: - request_id = self.get_request_id() - self.in_flight += 1 - break - - time.sleep(0.001) + # We've incremented self.in_flight above, so we "have permission" to + # acquire a new request id + request_id = self.get_request_id() self.send_msg(query, request_id, process_result) @@ -826,18 +1590,6 @@ def is_idle(self): def reset_idle(self): self.msg_received = False - @property - def server_version(self): - if self._server_version is None: - query_message = QueryMessage(query="SELECT release_version FROM system.local", consistency_level=ConsistencyLevel.ONE) - message = self.wait_for_response(query_message) - self._server_version = message.results[1][0][0] # (col names, rows)[rows][first row][only item] - return self._server_version - - @server_version.setter - def server_version(self, version): - self._server_version = version - def __str__(self): status = "" if self.is_defunct: @@ -845,7 +1597,7 @@ def __str__(self): elif self.is_closed: status = " (closed)" - return "<%s(%r) %s:%d%s>" % (self.__class__.__name__, id(self), self.host, self.port, status) + return "<%s(%r) %s%s>" % (self.__class__.__name__, id(self), self.endpoint, status) __repr__ = __str__ @@ -906,7 +1658,7 @@ def __init__(self, connection, owner): self.connection = connection self.owner = owner log.debug("Sending options message heartbeat on idle connection (%s) %s", - id(connection), connection.host) + id(connection), connection.endpoint) with connection.lock: if connection.in_flight < connection.max_request_id: connection.in_flight += 1 @@ -921,26 +1673,27 @@ def wait(self, timeout): if 
self._exception: raise self._exception else: - raise OperationTimedOut() + raise OperationTimedOut("Connection heartbeat timeout after %s seconds" % (timeout,), self.connection.endpoint) def _options_callback(self, response): - if not isinstance(response, SupportedMessage): + if isinstance(response, SupportedMessage): + log.debug("Received options response on connection (%s) from %s", + id(self.connection), self.connection.endpoint) + else: if isinstance(response, ConnectionException): self._exception = response else: self._exception = ConnectionException("Received unexpected response to OptionsMessage: %s" % (response,)) - - log.debug("Received options response on connection (%s) from %s", - id(self.connection), self.connection.host) self._event.set() class ConnectionHeartbeat(Thread): - def __init__(self, interval_sec, get_connection_holders): + def __init__(self, interval_sec, get_connection_holders, timeout): Thread.__init__(self, name="Connection heartbeat") self._interval = interval_sec + self._timeout = timeout self._get_connection_holders = get_connection_holders self._shutdown_event = Event() self.daemon = True @@ -964,34 +1717,44 @@ def run(self): if connection.is_idle: try: futures.append(HeartbeatFuture(connection, owner)) - except Exception: + except Exception as e: log.warning("Failed sending heartbeat message on connection (%s) to %s", - id(connection), connection.host, exc_info=True) - failed_connections.append((connection, owner)) + id(connection), connection.endpoint) + failed_connections.append((connection, owner, e)) else: connection.reset_idle() else: - # make sure the owner sees this defunt/closed connection + log.debug("Cannot send heartbeat message on connection (%s) to %s", + id(connection), connection.endpoint) + # make sure the owner sees this defunct/closed connection owner.return_connection(connection) self._raise_if_stopped() + # Wait max `self._timeout` seconds for all HeartbeatFutures to complete + timeout = self._timeout + start_time = time.time() for f in futures: self._raise_if_stopped() connection = f.connection try: - f.wait(self._interval) + f.wait(timeout) # TODO: move this, along with connection locks in pool, down into Connection with connection.lock: connection.in_flight -= 1 connection.reset_idle() - except Exception: + except Exception as e: log.warning("Heartbeat failed for connection (%s) to %s", - id(connection), connection.host, exc_info=True) - failed_connections.append((f.connection, f.owner)) + id(connection), connection.endpoint) + failed_connections.append((f.connection, f.owner, e)) - for connection, owner in failed_connections: + timeout = self._timeout - (time.time() - start_time) + + for connection, owner, exc in failed_connections: self._raise_if_stopped() - connection.defunct(Exception('Connection heartbeat failure')) + if not connection.is_control_connection: + # Only HostConnection supports shutdown_on_error + owner.shutdown_on_error = True + connection.defunct(exc) owner.return_connection(connection) except self.ShutdownException: pass @@ -1017,8 +1780,6 @@ class Timer(object): def __init__(self, timeout, callback): self.end = time.time() + timeout self.callback = callback - if timeout < 0: - self.callback() def __lt__(self, other): return self.end < other.end diff --git a/cassandra/cqlengine/__init__.py b/cassandra/cqlengine/__init__.py index 48b9da7bfb..200d04b831 100644 --- a/cassandra/cqlengine/__init__.py +++ b/cassandra/cqlengine/__init__.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. 
+# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,9 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import six - - # Caching constants. CACHING_ALL = "ALL" CACHING_KEYS_ONLY = "KEYS_ONLY" @@ -31,7 +30,4 @@ class ValidationError(CQLEngineException): class UnicodeMixin(object): - if six.PY3: - __str__ = lambda x: x.__unicode__() - else: - __str__ = lambda x: six.text_type(x).encode('utf-8') + __str__ = lambda x: x.__unicode__() diff --git a/cassandra/cqlengine/columns.py b/cassandra/cqlengine/columns.py index c441b8f23b..7d50687d95 100644 --- a/cassandra/cqlengine/columns.py +++ b/cassandra/cqlengine/columns.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +15,15 @@ # limitations under the License. 
from copy import deepcopy, copy -from datetime import date, datetime, timedelta +from datetime import date, datetime, timedelta, timezone import logging -import six from uuid import UUID as _UUID from cassandra import util -from cassandra.cqltypes import SimpleDateType +from cassandra.cqltypes import SimpleDateType, _cqltypes, UserType from cassandra.cqlengine import ValidationError from cassandra.cqlengine.functions import get_total_seconds +from cassandra.util import Duration as _Duration log = logging.getLogger(__name__) @@ -37,7 +39,7 @@ def __init__(self, instance, column, value): @property def deleted(self): - return self.column._val_is_null(self.value) and (self.explicit or self.previous_value is not None) + return self.column._val_is_null(self.value) and (self.explicit or not self.column._val_is_null(self.previous_value)) @property def changed(self): @@ -47,7 +49,19 @@ def changed(self): :rtype: boolean """ - return self.value != self.previous_value + if self.explicit: + return self.value != self.previous_value + + if isinstance(self.column, BaseContainerColumn): + default_value = self.column.get_default() + if self.column._val_is_null(default_value): + return not self.column._val_is_null(self.value) and self.value != self.previous_value + elif self.previous_value is None: + return self.value != default_value + + return self.value != self.previous_value + + return False def reset_previous_value(self): self.previous_value = deepcopy(self.value) @@ -57,6 +71,7 @@ def getval(self): def setval(self, val): self.value = val + self.explicit = True def delval(self): self.value = None @@ -100,6 +115,13 @@ class Column(object): bool flag, indicates an index should be created for this column """ + custom_index = False + """ + bool flag, indicates an index is managed outside of cqlengine. This is + useful if you want to do filter queries on fields that have custom + indexes. + """ + db_field = None """ the fieldname this field will map to in the database @@ -147,10 +169,12 @@ def __init__(self, required=False, clustering_order=None, discriminator_column=False, - static=False): + static=False, + custom_index=False): self.partition_key = partition_key self.primary_key = partition_key or primary_key self.index = index + self.custom_index = custom_index self.db_field = db_field self.default = default self.required = required @@ -159,6 +183,7 @@ def __init__(self, # the column name in the model definition self.column_name = None + self._partition_key_index = None self.static = static self.value = None @@ -167,6 +192,39 @@ def __init__(self, self.position = Column.instance_counter Column.instance_counter += 1 + def __ne__(self, other): + if isinstance(other, Column): + return self.position != other.position + return NotImplemented + + def __eq__(self, other): + if isinstance(other, Column): + return self.position == other.position + return NotImplemented + + def __lt__(self, other): + if isinstance(other, Column): + return self.position < other.position + return NotImplemented + + def __le__(self, other): + if isinstance(other, Column): + return self.position <= other.position + return NotImplemented + + def __gt__(self, other): + if isinstance(other, Column): + return self.position > other.position + return NotImplemented + + def __ge__(self, other): + if isinstance(other, Column): + return self.position >= other.position + return NotImplemented + + def __hash__(self): + return id(self) + def validate(self, value): """ Returns a cleaned and validated value. 
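The rich comparison methods added to ``Column`` above order instances by their ``position`` counter, i.e. by definition order; a small illustrative check, assuming a standard cqlengine install::

    from cassandra.cqlengine import columns

    first = columns.Text()
    second = columns.Integer()

    # instance_counter increases per instantiation, so earlier-defined columns
    # sort first; model metaclasses rely on this to keep declaration order.
    assert first < second
    assert sorted([second, first]) == [first, second]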
Raises a ValidationError @@ -188,8 +246,6 @@ def to_database(self, value): """ Converts python value into database value """ - if value is None and self.has_default: - return self.get_default() return value @property @@ -233,13 +289,17 @@ def set_column_name(self, name): @property def db_field_name(self): """ Returns the name of the cql name of this column """ - return self.db_field or self.column_name + return self.db_field if self.db_field is not None else self.column_name @property def db_index_name(self): """ Returns the name of the cql index """ return 'index_{0}'.format(self.db_field_name) + @property + def has_index(self): + return self.index or self.custom_index + @property def cql(self): return self.get_cql() @@ -255,6 +315,10 @@ def _val_is_null(self, val): def sub_types(self): return [] + @property + def cql_type(self): + return _cqltypes[self.db_type] + class Blob(Column): """ @@ -264,7 +328,7 @@ class Blob(Column): def to_database(self, value): - if not isinstance(value, (six.binary_type, bytearray)): + if not isinstance(value, (bytes, bytearray)): raise Exception("expecting a binary, got a %s" % type(value)) val = super(Bytes, self).to_database(value) @@ -274,13 +338,6 @@ def to_database(self, value): Bytes = Blob -class Ascii(Column): - """ - Stores a US-ASCII character string - """ - db_type = 'ascii' - - class Inet(Column): """ Stores an IP address in IPv4 or IPv6 format @@ -300,25 +357,68 @@ def __init__(self, min_length=None, max_length=None, **kwargs): Defaults to 1 if this is a ``required`` column. Otherwise, None. :param int max_length: Sets the maximum length of this string, for validation purposes. """ - self.min_length = min_length or (1 if kwargs.get('required', False) else None) + self.min_length = ( + 1 if min_length is None and kwargs.get('required', False) + else min_length) self.max_length = max_length + + if self.min_length is not None: + if self.min_length < 0: + raise ValueError( + 'Minimum length is not allowed to be negative.') + + if self.max_length is not None: + if self.max_length < 0: + raise ValueError( + 'Maximum length is not allowed to be negative.') + + if self.min_length is not None and self.max_length is not None: + if self.max_length < self.min_length: + raise ValueError( + 'Maximum length must be greater or equal ' + 'to minimum length.') + super(Text, self).__init__(**kwargs) def validate(self, value): value = super(Text, self).validate(value) - if value is None: - return - if not isinstance(value, (six.string_types, bytearray)) and value is not None: + if not isinstance(value, (str, bytearray)) and value is not None: raise ValidationError('{0} {1} is not a string'.format(self.column_name, type(value))) - if self.max_length: - if len(value) > self.max_length: + if self.max_length is not None: + if value and len(value) > self.max_length: raise ValidationError('{0} is longer than {1} characters'.format(self.column_name, self.max_length)) if self.min_length: - if len(value) < self.min_length: + if (self.min_length and not value) or len(value) < self.min_length: raise ValidationError('{0} is shorter than {1} characters'.format(self.column_name, self.min_length)) return value +class Ascii(Text): + """ + Stores a US-ASCII character string + """ + db_type = 'ascii' + + def validate(self, value): + """ Only allow ASCII and None values. + + Check against US-ASCII, a.k.a. 7-bit ASCII, a.k.a. ISO646-US, a.k.a. + the Basic Latin block of the Unicode character set. 
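The tightened ``Text`` length handling above checks the bounds at definition time (``ValueError``) and the value at validation time (``ValidationError``); a brief illustration::

    from cassandra.cqlengine import ValidationError, columns

    name = columns.Text(min_length=1, max_length=8)

    try:
        columns.Text(min_length=10, max_length=5)
    except ValueError as exc:
        print(exc)   # maximum length must be >= minimum length

    try:
        name.validate('x' * 9)
    except ValidationError as exc:
        print(exc)   # value is longer than 8 characters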
+ + Source: https://github.com/apache/cassandra/blob + /3dcbe90e02440e6ee534f643c7603d50ca08482b/src/java/org/apache/cassandra + /serializers/AsciiSerializer.java#L29 + """ + value = super(Ascii, self).validate(value) + if value: + charset = value if isinstance( + value, (bytearray, )) else map(ord, value) + if not set(range(128)).issuperset(charset): + raise ValidationError( + '{!r} is not an ASCII string.'.format(value)) + return value + + class Integer(Column): """ Stores a 32-bit signed integer value @@ -403,7 +503,7 @@ def __init__(self, instance, column, value): class Counter(Integer): """ - Stores a counter that can be inremented and decremented + Stores a counter that can be incremented and decremented """ db_type = 'counter' @@ -453,7 +553,7 @@ def to_python(self, value): elif isinstance(value, date): return datetime(*(value.timetuple()[:6])) - return datetime.utcfromtimestamp(value) + return datetime.fromtimestamp(value, tz=timezone.utc).replace(tzinfo=None) def to_database(self, value): value = super(DateTime, self).to_database(value) @@ -483,7 +583,6 @@ class Date(Column): db_type = 'date' def to_database(self, value): - value = super(Date, self).to_database(value) if value is None: return @@ -492,6 +591,14 @@ def to_database(self, value): d = value if isinstance(value, util.Date) else util.Date(value) return d.days_from_epoch + SimpleDateType.EPOCH_OFFSET_DAYS + def to_python(self, value): + if value is None: + return + if isinstance(value, util.Date): + return value + if isinstance(value, datetime): + value = value.date() + return util.Date(value) class Time(Column): """ @@ -510,6 +617,32 @@ def to_database(self, value): # str(util.Time) yields desired CQL encoding return value if isinstance(value, util.Time) else util.Time(value) + def to_python(self, value): + value = super(Time, self).to_database(value) + if value is None: + return + if isinstance(value, util.Time): + return value + return util.Time(value) + +class Duration(Column): + """ + Stores a duration (months, days, nanoseconds) + + .. versionadded:: 3.10.0 + + requires C* 3.10+ and protocol v4+ + """ + db_type = 'duration' + + def validate(self, value): + val = super(Duration, self).validate(value) + if val is None: + return + if not isinstance(val, _Duration): + raise TypeError('{0} {1} is not a valid Duration.'.format(self.column_name, value)) + return val + class UUID(Column): """ @@ -523,7 +656,7 @@ def validate(self, value): return if isinstance(val, _UUID): return val - if isinstance(val, six.string_types): + if isinstance(val, str): try: return _UUID(val) except ValueError: @@ -625,7 +758,7 @@ class BaseCollectionColumn(Column): """ Base Container type for collection-like columns. 
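The ``DateTime.to_python`` change above replaces the deprecated ``datetime.utcfromtimestamp()`` with the timezone-aware API while still returning a naive UTC datetime; an equivalent helper (name hypothetical)::

    from datetime import datetime, timezone

    def naive_utc_from_timestamp(ts):
        # Same result as the deprecated datetime.utcfromtimestamp(ts).
        return datetime.fromtimestamp(ts, tz=timezone.utc).replace(tzinfo=None)

    assert naive_utc_from_timestamp(0) == datetime(1970, 1, 1)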
- https://cassandra.apache.org/doc/cql3/CQL.html#collections + http://cassandra.apache.org/doc/cql3/CQL-3.0.html#collections """ def __init__(self, types, **kwargs): """ @@ -665,6 +798,10 @@ def _freeze_db_type(self): def sub_types(self): return self.types + @property + def cql_type(self): + return _cqltypes[self.__class__.__name__.lower()].apply_parameters([c.cql_type for c in self.types]) + class Tuple(BaseCollectionColumn): """ @@ -797,7 +934,7 @@ class Map(BaseContainerColumn): """ Stores a key -> value map (dictionary) - http://www.datastax.com/documentation/cql/3.1/cql/cql_using/use_map_t.html + https://docs.datastax.com/en/dse/6.7/cql/cql/cql_using/useMap.html """ _python_type_hashable = False @@ -843,7 +980,16 @@ def to_database(self, value): class UDTValueManager(BaseValueManager): @property def changed(self): - return self.value != self.previous_value or (self.value is not None and self.value.has_changed_fields()) + if self.explicit: + return self.value != self.previous_value + + default_value = self.column.get_default() + if not self.column._val_is_null(default_value): + return self.value != default_value + elif self.previous_value is None: + return not self.column._val_is_null(self.value) and self.value.has_changed_fields() + + return False def reset_previous_value(self): if self.value is not None: @@ -876,6 +1022,41 @@ def __init__(self, user_type, **kwargs): def sub_types(self): return list(self.user_type._fields.values()) + @property + def cql_type(self): + return UserType.make_udt_class(keyspace='', udt_name=self.user_type.type_name(), + field_names=[c.db_field_name for c in self.user_type._fields.values()], + field_types=[c.cql_type for c in self.user_type._fields.values()]) + + def validate(self, value): + val = super(UserDefinedType, self).validate(value) + if val is None: + return + val.validate() + return val + + def to_python(self, value): + if value is None: + return + + copied_value = deepcopy(value) + for name, field in self.user_type._fields.items(): + if copied_value[name] is not None or isinstance(field, BaseContainerColumn): + copied_value[name] = field.to_python(copied_value[name]) + + return copied_value + + def to_database(self, value): + if value is None: + return + + copied_value = deepcopy(value) + for name, field in self.user_type._fields.items(): + if copied_value[name] is not None or isinstance(field, BaseContainerColumn): + copied_value[name] = field.to_database(copied_value[name]) + + return copied_value + def resolve_udts(col_def, out_list): for col in col_def.sub_types: @@ -891,7 +1072,7 @@ class _PartitionKeysToken(Column): """ def __init__(self, model): - self.partition_columns = model._partition_keys.values() + self.partition_columns = list(model._partition_keys.values()) super(_PartitionKeysToken, self).__init__(partition_key=True) @property diff --git a/cassandra/cqlengine/connection.py b/cassandra/cqlengine/connection.py index a737f29e65..55437d7b7f 100644 --- a/cassandra/cqlengine/connection.py +++ b/cassandra/cqlengine/connection.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,13 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple, defaultdict +from collections import defaultdict import logging -import six import threading -from cassandra.cluster import Cluster, _NOT_SET, NoHostAvailable, UserTypeDoesNotExist -from cassandra.query import SimpleStatement, Statement, dict_factory +from cassandra.cluster import Cluster, _ConfigMode, _NOT_SET, NoHostAvailable, UserTypeDoesNotExist, ConsistencyLevel +from cassandra.query import SimpleStatement, dict_factory from cassandra.cqlengine import CQLEngineException from cassandra.cqlengine.statements import BaseCQLStatement @@ -28,13 +29,12 @@ NOT_SET = _NOT_SET # required for passing timeout to Session.execute -Host = namedtuple('Host', ['name', 'port']) - cluster = None session = None -lazy_connect_args = None -lazy_connect_lock = threading.RLock() +# connections registry +DEFAULT_CONNECTION = object() +_connections = {} # Because type models may be registered before a connection is present, # and because sessions may be replaced, we must register UDTs here, in order @@ -42,55 +42,274 @@ udt_by_keyspace = defaultdict(dict) +def format_log_context(msg, connection=None, keyspace=None): + """Format log message to add keyspace and connection context""" + connection_info = connection or 'DEFAULT_CONNECTION' + + if keyspace: + msg = '[Connection: {0}, Keyspace: {1}] {2}'.format(connection_info, keyspace, msg) + else: + msg = '[Connection: {0}] {1}'.format(connection_info, msg) + return msg + + class UndefinedKeyspaceException(CQLEngineException): pass -def default(): +class Connection(object): + """CQLEngine Connection""" + + name = None + hosts = None + + consistency = None + retry_connect = False + lazy_connect = False + lazy_connect_lock = None + cluster_options = None + + cluster = None + session = None + + def __init__(self, name, hosts, consistency=None, + lazy_connect=False, retry_connect=False, cluster_options=None): + self.hosts = hosts + self.name = name + self.consistency = consistency + self.lazy_connect = lazy_connect + self.retry_connect = retry_connect + self.cluster_options = cluster_options if cluster_options else {} + self.lazy_connect_lock = threading.RLock() + + @classmethod + def from_session(cls, name, session): + instance = cls(name=name, hosts=session.hosts) + instance.cluster, instance.session = session.cluster, session + instance.setup_session() + return instance + + def setup(self): + """Set up the connection""" + global cluster, session + + if 'username' in self.cluster_options or 'password' in self.cluster_options: + raise CQLEngineException("Username & Password are now handled by using the native driver's auth_provider") + + if self.lazy_connect: + return + + if 'cloud' in self.cluster_options: + if self.hosts: + log.warning("Ignoring hosts %s because a cloud config was provided.", self.hosts) + self.cluster = Cluster(**self.cluster_options) + else: + self.cluster = Cluster(self.hosts, **self.cluster_options) + + try: + self.session = 
self.cluster.connect() + log.debug(format_log_context("connection initialized with internally created session", connection=self.name)) + except NoHostAvailable: + if self.retry_connect: + log.warning(format_log_context("connect failed, setting up for re-attempt on first use", connection=self.name)) + self.lazy_connect = True + raise + + if DEFAULT_CONNECTION in _connections and _connections[DEFAULT_CONNECTION] == self: + cluster = _connections[DEFAULT_CONNECTION].cluster + session = _connections[DEFAULT_CONNECTION].session + + self.setup_session() + + def setup_session(self): + if self.cluster._config_mode == _ConfigMode.PROFILES: + self.cluster.profile_manager.default.row_factory = dict_factory + if self.consistency is not None: + self.cluster.profile_manager.default.consistency_level = self.consistency + else: + self.session.row_factory = dict_factory + if self.consistency is not None: + self.session.default_consistency_level = self.consistency + enc = self.session.encoder + enc.mapping[tuple] = enc.cql_encode_tuple + _register_known_types(self.session.cluster) + + def handle_lazy_connect(self): + + # if lazy_connect is False, it means the cluster is set up and ready + # No need to acquire the lock + if not self.lazy_connect: + return + + with self.lazy_connect_lock: + # lazy_connect might have been set to False by another thread while waiting the lock + # In this case, do nothing. + if self.lazy_connect: + log.debug(format_log_context("Lazy connect enabled", connection=self.name)) + self.lazy_connect = False + self.setup() + + +def register_connection(name, hosts=None, consistency=None, lazy_connect=False, + retry_connect=False, cluster_options=None, default=False, + session=None): """ - Configures the global mapper connection to localhost, using the driver defaults - (except for row_factory) + Add a connection to the connection registry. ``hosts`` and ``session`` are + mutually exclusive, and ``consistency``, ``lazy_connect``, + ``retry_connect``, and ``cluster_options`` only work with ``hosts``. Using + ``hosts`` will create a new :class:`cassandra.cluster.Cluster` and + :class:`cassandra.cluster.Session`. + + :param list hosts: list of hosts, (``contact_points`` for :class:`cassandra.cluster.Cluster`). + :param int consistency: The default :class:`~.ConsistencyLevel` for the + registered connection's new session. Default is the same as + :attr:`.Session.default_consistency_level`. For use with ``hosts`` only; + will fail when used with ``session``. + :param bool lazy_connect: True if should not connect until first use. For + use with ``hosts`` only; will fail when used with ``session``. + :param bool retry_connect: True if we should retry to connect even if there + was a connection failure initially. For use with ``hosts`` only; will + fail when used with ``session``. + :param dict cluster_options: A dict of options to be used as keyword + arguments to :class:`cassandra.cluster.Cluster`. For use with ``hosts`` + only; will fail when used with ``session``. + :param bool default: If True, set the new connection as the cqlengine + default + :param Session session: A :class:`cassandra.cluster.Session` to be used in + the created connection. 
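A hedged usage sketch for the registry API documented above; the connection name and contact point are placeholders, and ``session=`` is mutually exclusive with the host/cluster options::

    from cassandra.cqlengine import connection

    # Driver-managed Cluster/Session, registered under a name and made default.
    connection.register_connection(
        'app',                       # placeholder connection name
        hosts=['127.0.0.1'],
        lazy_connect=True,
        retry_connect=True,
        default=True,
    )

    # Or wrap an existing Session instead of passing hosts/options:
    # connection.register_connection('reporting', session=my_session)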
""" + + if name in _connections: + log.warning("Registering connection '{0}' when it already exists.".format(name)) + + if session is not None: + invalid_config_args = (hosts is not None or + consistency is not None or + lazy_connect is not False or + retry_connect is not False or + cluster_options is not None) + if invalid_config_args: + raise CQLEngineException( + "Session configuration arguments and 'session' argument are mutually exclusive" + ) + conn = Connection.from_session(name, session=session) + else: # use hosts argument + conn = Connection( + name, hosts=hosts, + consistency=consistency, lazy_connect=lazy_connect, + retry_connect=retry_connect, cluster_options=cluster_options + ) + conn.setup() + + _connections[name] = conn + + if default: + set_default_connection(name) + + return conn + + +def unregister_connection(name): global cluster, session - if session: - log.warning("configuring new connection for cqlengine when one was already set") + if name not in _connections: + return + + if DEFAULT_CONNECTION in _connections and _connections[name] == _connections[DEFAULT_CONNECTION]: + del _connections[DEFAULT_CONNECTION] + cluster = None + session = None + + conn = _connections[name] + if conn.cluster: + conn.cluster.shutdown() + del _connections[name] + log.debug("Connection '{0}' has been removed from the registry.".format(name)) + + +def set_default_connection(name): + global cluster, session + + if name not in _connections: + raise CQLEngineException("Connection '{0}' doesn't exist.".format(name)) + + log.debug("Connection '{0}' has been set as default.".format(name)) + _connections[DEFAULT_CONNECTION] = _connections[name] + cluster = _connections[name].cluster + session = _connections[name].session + + +def get_connection(name=None): + + if not name: + name = DEFAULT_CONNECTION + + if name not in _connections: + raise CQLEngineException("Connection name '{0}' doesn't exist in the registry.".format(name)) - cluster = Cluster() - session = cluster.connect() + conn = _connections[name] + conn.handle_lazy_connect() - _setup_session(session) + return conn + + +def default(): + """ + Configures the default connection to localhost, using the driver defaults + (except for row_factory) + """ + + try: + conn = get_connection() + if conn.session: + log.warning("configuring new default connection for cqlengine when one was already set") + except: + pass + + register_connection('default', hosts=None, default=True) log.debug("cqlengine connection initialized with default session to localhost") def set_session(s): """ - Configures the global mapper connection with a preexisting :class:`cassandra.cluster.Session` + Configures the default connection with a preexisting :class:`cassandra.cluster.Session` Note: the mapper presently requires a Session :attr:`~.row_factory` set to ``dict_factory``. 
This may be relaxed in the future """ - global cluster, session - - if session: - log.warning("configuring new connection for cqlengine when one was already set") - if s.row_factory is not dict_factory: - raise CQLEngineException("Failed to initialize: 'Session.row_factory' must be 'dict_factory'.") - session = s - cluster = s.cluster + try: + conn = get_connection() + except CQLEngineException: + # no default connection set; initialize one + register_connection('default', session=s, default=True) + conn = get_connection() + else: + if conn.session: + log.warning("configuring new default session for cqlengine when one was already set") + + if not any([ + s.cluster.profile_manager.default.row_factory is dict_factory and s.cluster._config_mode in [_ConfigMode.PROFILES, _ConfigMode.UNCOMMITTED], + s.row_factory is dict_factory and s.cluster._config_mode in [_ConfigMode.LEGACY, _ConfigMode.UNCOMMITTED], + ]): + raise CQLEngineException("Failed to initialize: row_factory must be 'dict_factory'") + + conn.session = s + conn.cluster = s.cluster # Set default keyspace from given session's keyspace - if session.keyspace: + if conn.session.keyspace: from cassandra.cqlengine import models - models.DEFAULT_KEYSPACE = session.keyspace + models.DEFAULT_KEYSPACE = conn.session.keyspace - _setup_session(session) + conn.setup_session() - log.debug("cqlengine connection initialized with %s", s) + log.debug("cqlengine default connection initialized with %s", s) +# TODO next major: if a cloud config is specified in kwargs, hosts will be ignored. +# This function should be refactored to reflect this change. PYTHON-1265 def setup( hosts, default_keyspace, @@ -99,116 +318,64 @@ def setup( retry_connect=False, **kwargs): """ - Setup a the driver connection used by the mapper + Set up the driver connection used by the mapper :param list hosts: list of hosts, (``contact_points`` for :class:`cassandra.cluster.Cluster`) :param str default_keyspace: The default keyspace to use :param int consistency: The global default :class:`~.ConsistencyLevel` - default is the same as :attr:`.Session.default_consistency_level` :param bool lazy_connect: True if should not connect until first use :param bool retry_connect: True if we should retry to connect even if there was a connection failure initially - :param \*\*kwargs: Pass-through keyword arguments for :class:`cassandra.cluster.Cluster` + :param kwargs: Pass-through keyword arguments for :class:`cassandra.cluster.Cluster` """ - global cluster, session, lazy_connect_args - - if 'username' in kwargs or 'password' in kwargs: - raise CQLEngineException("Username & Password are now handled by using the native driver's auth_provider") from cassandra.cqlengine import models models.DEFAULT_KEYSPACE = default_keyspace - if lazy_connect: - kwargs['default_keyspace'] = default_keyspace - kwargs['consistency'] = consistency - kwargs['lazy_connect'] = False - kwargs['retry_connect'] = retry_connect - lazy_connect_args = (hosts, kwargs) - return + register_connection('default', hosts=hosts, consistency=consistency, lazy_connect=lazy_connect, + retry_connect=retry_connect, cluster_options=kwargs, default=True) - cluster = Cluster(hosts, **kwargs) - try: - session = cluster.connect() - log.debug("cqlengine connection initialized with internally created session") - except NoHostAvailable: - if retry_connect: - log.warning("connect failed, setting up for re-attempt on first use") - kwargs['default_keyspace'] = default_keyspace - kwargs['consistency'] = consistency - kwargs['lazy_connect'] = 
False - kwargs['retry_connect'] = retry_connect - lazy_connect_args = (hosts, kwargs) - raise - if consistency is not None: - session.default_consistency_level = consistency - _setup_session(session) +def execute(query, params=None, consistency_level=None, timeout=NOT_SET, connection=None): + conn = get_connection(connection) -def _setup_session(session): - session.row_factory = dict_factory - enc = session.encoder - enc.mapping[tuple] = enc.cql_encode_tuple - _register_known_types(session.cluster) - - -def execute(query, params=None, consistency_level=None, timeout=NOT_SET): - - handle_lazy_connect() - - if not session: + if not conn.session: raise CQLEngineException("It is required to setup() cqlengine before executing queries") - if isinstance(query, Statement): - pass - + if isinstance(query, SimpleStatement): + pass # elif isinstance(query, BaseCQLStatement): params = query.get_context() query = SimpleStatement(str(query), consistency_level=consistency_level, fetch_size=query.fetch_size) - - elif isinstance(query, six.string_types): + elif isinstance(query, str): query = SimpleStatement(query, consistency_level=consistency_level) + log.debug(format_log_context('Query: {}, Params: {}'.format(query.query_string, params), connection=connection)) - log.debug(query.query_string) - - params = params or {} - result = session.execute(query, params, timeout=timeout) + result = conn.session.execute(query, params, timeout=timeout) return result -def get_session(): - handle_lazy_connect() - return session +def get_session(connection=None): + conn = get_connection(connection) + return conn.session -def get_cluster(): - handle_lazy_connect() - if not cluster: +def get_cluster(connection=None): + conn = get_connection(connection) + if not conn.cluster: raise CQLEngineException("%s.cluster is not configured. Call one of the setup or default functions first." % __name__) - return cluster + return conn.cluster -def handle_lazy_connect(): - global lazy_connect_args - - # if lazy_connect_args is None, it means the cluster is setup and ready - # No need to acquire the lock - if not lazy_connect_args: - return - - with lazy_connect_lock: - # lazy_connect_args might have been set to None by another thread while waiting the lock - # In this case, do nothing. - if lazy_connect_args: - log.debug("lazy connect") - hosts, kwargs = lazy_connect_args - setup(hosts, **kwargs) - lazy_connect_args = None - - -def register_udt(keyspace, type_name, klass): +def register_udt(keyspace, type_name, klass, connection=None): udt_by_keyspace[keyspace][type_name] = klass - global cluster + try: + cluster = get_cluster(connection) + except CQLEngineException: + cluster = None + if cluster: try: cluster.register_user_type(keyspace, type_name, klass) diff --git a/cassandra/cqlengine/functions.py b/cassandra/cqlengine/functions.py index ccfe9de93a..69bdc3feb4 100644 --- a/cassandra/cqlengine/functions.py +++ b/cassandra/cqlengine/functions.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
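To round out the connection-aware helpers defined above (``execute``, ``get_session``, ``get_cluster``), a sketch; the ``analytics`` connection name is hypothetical::

    from cassandra.cqlengine import connection

    connection.register_connection('analytics', hosts=['127.0.0.1'])

    # run a raw statement against a specific registered connection
    rows = connection.execute(
        "SELECT release_version FROM system.local",
        connection='analytics')

    # or drop down to the driver-level Session / Cluster objects
    session = connection.get_session(connection='analytics')
    cluster = connection.get_cluster(connection='analytics')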
You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,21 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import division from datetime import datetime from cassandra.cqlengine import UnicodeMixin, ValidationError -import sys - -if sys.version_info >= (2, 7): - def get_total_seconds(td): - return td.total_seconds() -else: - def get_total_seconds(td): - # integer division used here to emulate built-in total_seconds - return ((86400 * td.days + td.seconds) * 10 ** 6 + td.microseconds) / 10 ** 6 - +def get_total_seconds(td): + return td.total_seconds() class QueryValue(UnicodeMixin): """ @@ -86,7 +79,7 @@ class MinTimeUUID(TimeUUIDQueryFunction): """ return a fake timeuuid corresponding to the smallest possible timeuuid for the given timestamp - http://cassandra.apache.org/doc/cql3/CQL.html#timeuuidFun + http://cassandra.apache.org/doc/cql3/CQL-3.0.html#timeuuidFun """ format_string = 'MinTimeUUID(%({0})s)' @@ -95,7 +88,7 @@ class MaxTimeUUID(TimeUUIDQueryFunction): """ return a fake timeuuid corresponding to the largest possible timeuuid for the given timestamp - http://cassandra.apache.org/doc/cql3/CQL.html#timeuuidFun + http://cassandra.apache.org/doc/cql3/CQL-3.0.html#timeuuidFun """ format_string = 'MaxTimeUUID(%({0})s)' @@ -104,7 +97,7 @@ class Token(BaseQueryFunction): """ compute the token for a given partition key - http://cassandra.apache.org/doc/cql3/CQL.html#tokenFun + http://cassandra.apache.org/doc/cql3/CQL-3.0.html#tokenFun """ def __init__(self, *values): if len(values) == 1 and isinstance(values[0], (list, tuple)): diff --git a/cassandra/cqlengine/management.py b/cassandra/cqlengine/management.py index a071e0b7c8..66b391b714 100644 --- a/cassandra/cqlengine/management.py +++ b/cassandra/cqlengine/management.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
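For context on the timeuuid helpers documented above, they are used as comparison values in mapper filters; a sketch with a hypothetical ``Comment`` model and a placeholder ``a_photo_id`` value::

    import uuid
    from datetime import datetime, timedelta
    from cassandra.cqlengine import columns
    from cassandra.cqlengine.functions import MinTimeUUID, MaxTimeUUID
    from cassandra.cqlengine.models import Model

    class Comment(Model):
        photo_id = columns.UUID(partition_key=True)
        comment_id = columns.TimeUUID(primary_key=True, default=uuid.uuid1)
        comment = columns.Text()

    window_start = datetime.utcnow() - timedelta(days=1)
    recent = Comment.objects.filter(
        photo_id=a_photo_id,                       # placeholder value
        comment_id__gt=MinTimeUUID(window_start),
        comment_id__lt=MaxTimeUUID(datetime.utcnow()))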
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -16,13 +18,13 @@ import json import logging import os -import six import warnings +from itertools import product from cassandra import metadata from cassandra.cqlengine import CQLEngineException -from cassandra.cqlengine import columns -from cassandra.cqlengine.connection import execute, get_cluster +from cassandra.cqlengine import columns, query +from cassandra.cqlengine.connection import execute, get_cluster, format_log_context from cassandra.cqlengine.models import Model from cassandra.cqlengine.named import NamedTable from cassandra.cqlengine.usertype import UserType @@ -37,7 +39,24 @@ schema_columnfamilies = NamedTable('system', 'schema_columnfamilies') -def create_keyspace_simple(name, replication_factor, durable_writes=True): +def _get_context(keyspaces, connections): + """Return all the execution contexts""" + + if keyspaces: + if not isinstance(keyspaces, (list, tuple)): + raise ValueError('keyspaces must be a list or a tuple.') + + if connections: + if not isinstance(connections, (list, tuple)): + raise ValueError('connections must be a list or a tuple.') + + keyspaces = keyspaces if keyspaces else [None] + connections = connections if connections else [None] + + return product(connections, keyspaces) + + +def create_keyspace_simple(name, replication_factor, durable_writes=True, connections=None): """ Creates a keyspace with SimpleStrategy for replica placement @@ -51,12 +70,13 @@ def create_keyspace_simple(name, replication_factor, durable_writes=True): :param str name: name of keyspace to create :param int replication_factor: keyspace replication factor, used with :attr:`~.SimpleStrategy` :param bool durable_writes: Write log is bypassed if set to False + :param list connections: List of connection names """ _create_keyspace(name, durable_writes, 'SimpleStrategy', - {'replication_factor': replication_factor}) + {'replication_factor': replication_factor}, connections=connections) -def create_keyspace_network_topology(name, dc_replication_map, durable_writes=True): +def create_keyspace_network_topology(name, dc_replication_map, durable_writes=True, connections=None): """ Creates a keyspace with NetworkTopologyStrategy for replica placement @@ -70,25 +90,37 @@ def create_keyspace_network_topology(name, dc_replication_map, durable_writes=Tr :param str name: name of keyspace to create :param dict dc_replication_map: map of dc_names: replication_factor :param bool durable_writes: Write log is bypassed if set to False + :param list connections: List of connection names """ - _create_keyspace(name, durable_writes, 'NetworkTopologyStrategy', dc_replication_map) + _create_keyspace(name, durable_writes, 'NetworkTopologyStrategy', dc_replication_map, connections=connections) -def _create_keyspace(name, durable_writes, strategy_class, strategy_options): +def _create_keyspace(name, durable_writes, strategy_class, strategy_options, connections=None): if not _allow_schema_modification(): return - cluster = get_cluster() + if connections: + if not isinstance(connections, (list, tuple)): + raise ValueError('Connections must be a list or a tuple.') - if name not in cluster.metadata.keyspaces: - log.info("Creating keyspace %s ", name) - ks_meta = metadata.KeyspaceMetadata(name, durable_writes, strategy_class, 
strategy_options) - execute(ks_meta.as_cql_query()) + def __create_keyspace(name, durable_writes, strategy_class, strategy_options, connection=None): + cluster = get_cluster(connection) + + if name not in cluster.metadata.keyspaces: + log.info(format_log_context("Creating keyspace %s", connection=connection), name) + ks_meta = metadata.KeyspaceMetadata(name, durable_writes, strategy_class, strategy_options) + execute(ks_meta.as_cql_query(), connection=connection) + else: + log.info(format_log_context("Not creating keyspace %s because it already exists", connection=connection), name) + + if connections: + for connection in connections: + __create_keyspace(name, durable_writes, strategy_class, strategy_options, connection=connection) else: - log.info("Not creating keyspace %s because it already exists", name) + __create_keyspace(name, durable_writes, strategy_class, strategy_options) -def drop_keyspace(name): +def drop_keyspace(name, connections=None): """ Drops a keyspace, if it exists. @@ -98,32 +130,48 @@ def drop_keyspace(name): Take care to execute schema modifications in a single context (i.e. not concurrently with other clients).** :param str name: name of keyspace to drop + :param list connections: List of connection names """ if not _allow_schema_modification(): return - cluster = get_cluster() - if name in cluster.metadata.keyspaces: - execute("DROP KEYSPACE {0}".format(metadata.protect_name(name))) + if connections: + if not isinstance(connections, (list, tuple)): + raise ValueError('Connections must be a list or a tuple.') + def _drop_keyspace(name, connection=None): + cluster = get_cluster(connection) + if name in cluster.metadata.keyspaces: + execute("DROP KEYSPACE {0}".format(metadata.protect_name(name)), connection=connection) + + if connections: + for connection in connections: + _drop_keyspace(name, connection) + else: + _drop_keyspace(name) def _get_index_name_by_column(table, column_name): """ Find the index name for a given table and column. """ - for _, index_metadata in six.iteritems(table.indexes): + protected_name = metadata.protect_name(column_name) + possible_index_values = [protected_name, "values(%s)" % protected_name] + for index_metadata in table.indexes.values(): options = dict(index_metadata.index_options) - possible_index_values = [column_name, "values(%s)" % column_name] - if 'target' in options and options['target'] in possible_index_values: + if options.get('target') in possible_index_values: return index_metadata.name - return None - -def sync_table(model): +def sync_table(model, keyspaces=None, connections=None): """ Inspects the model and creates / updates the corresponding table and columns. + If `keyspaces` is specified, the table will be synched for all specified keyspaces. + Note that the `Model.__keyspace__` is ignored in that case. + + If `connections` is specified, the table will be synched for all specified connections. Note that the `Model.__connection__` is ignored in that case. + If not specified, it will try to get the connection from the Model. + Any User Defined Types used in the table are implicitly synchronized. This function can only add fields that are not part of the primary key. 
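A sketch of how the new ``keyspaces``/``connections`` parameters compose for schema management (the model, keyspace, and connection names are placeholders)::

    from cassandra.cqlengine import connection, management

    connection.register_connection('cluster1', hosts=['127.0.0.1'], default=True)

    # create keyspaces on the registered connection
    management.create_keyspace_simple('ks_a', replication_factor=1, connections=['cluster1'])
    management.create_keyspace_simple('ks_b', replication_factor=1, connections=['cluster1'])

    # sync the model's table into both keyspaces over that connection,
    # ignoring SomeModel.__keyspace__ / SomeModel.__connection__
    management.sync_table(SomeModel, keyspaces=['ks_a', 'ks_b'], connections=['cluster1'])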
@@ -136,6 +184,14 @@ def sync_table(model): *There are plans to guard schema-modifying functions with an environment-driven conditional.* """ + + context = _get_context(keyspaces, connections) + for connection, keyspace in context: + with query.ContextQuery(model, keyspace=keyspace) as m: + _sync_table(m, connection=connection) + + +def _sync_table(model, connection=None): if not _allow_schema_modification(): return @@ -149,13 +205,15 @@ def sync_table(model): raw_cf_name = model._raw_column_family_name() ks_name = model._get_keyspace() + connection = connection or model._get_connection() - cluster = get_cluster() + cluster = get_cluster(connection) try: keyspace = cluster.metadata.keyspaces[ks_name] except KeyError: - raise CQLEngineException("Keyspace '{0}' for model {1} does not exist.".format(ks_name, model)) + msg = format_log_context("Keyspace '{0}' for model {1} does not exist.", connection=connection) + raise CQLEngineException(msg.format(ks_name, model)) tables = keyspace.tables @@ -164,21 +222,21 @@ def sync_table(model): udts = [] columns.resolve_udts(col, udts) for udt in [u for u in udts if u not in syncd_types]: - _sync_type(ks_name, udt, syncd_types) + _sync_type(ks_name, udt, syncd_types, connection=connection) if raw_cf_name not in tables: - log.debug("sync_table creating new table %s", cf_name) + log.debug(format_log_context("sync_table creating new table %s", keyspace=ks_name, connection=connection), cf_name) qs = _get_create_table(model) try: - execute(qs) + execute(qs, connection=connection) except CQLEngineException as ex: # 1.2 doesn't return cf names, so we have to examine the exception # and ignore if it says the column family already exists - if "Cannot add already existing column family" not in unicode(ex): + if "Cannot add already existing column family" not in str(ex): raise else: - log.debug("sync_table checking existing table %s", cf_name) + log.debug(format_log_context("sync_table checking existing table %s", keyspace=ks_name, connection=connection), cf_name) table_meta = tables[raw_cf_name] _validate_pk(model, table_meta) @@ -192,24 +250,27 @@ def sync_table(model): if db_name in table_columns: col_meta = table_columns[db_name] if col_meta.cql_type != col.db_type: - msg = 'Existing table {0} has column "{1}" with a type ({2}) differing from the model type ({3}).' \ - ' Model should be updated.'.format(cf_name, db_name, col_meta.cql_type, col.db_type) + msg = format_log_context('Existing table {0} has column "{1}" with a type ({2}) differing from the model type ({3}).' 
+ ' Model should be updated.', keyspace=ks_name, connection=connection) + msg = msg.format(cf_name, db_name, col_meta.cql_type, col.db_type) warnings.warn(msg) log.warning(msg) continue - if col.primary_key or col.primary_key: - raise CQLEngineException("Cannot add primary key '{0}' (with db_field '{1}') to existing table {2}".format(model_name, db_name, cf_name)) + if col.primary_key: + msg = format_log_context("Cannot add primary key '{0}' (with db_field '{1}') to existing table {2}", keyspace=ks_name, connection=connection) + raise CQLEngineException(msg.format(model_name, db_name, cf_name)) query = "ALTER TABLE {0} add {1}".format(cf_name, col.get_column_def()) - execute(query) + execute(query, connection=connection) db_fields_not_in_model = model_fields.symmetric_difference(table_columns) if db_fields_not_in_model: - log.info("Table {0} has fields not referenced by model: {1}".format(cf_name, db_fields_not_in_model)) + msg = format_log_context("Table {0} has fields not referenced by model: {1}", keyspace=ks_name, connection=connection) + log.info(msg.format(cf_name, db_fields_not_in_model)) - _update_options(model) + _update_options(model, connection=connection) table = cluster.metadata.keyspaces[ks_name].tables[raw_cf_name] @@ -225,7 +286,7 @@ def sync_table(model): qs += ['ON {0}'.format(cf_name)] qs += ['("{0}")'.format(column.db_field_name)] qs = ' '.join(qs) - execute(qs) + execute(qs, connection=connection) def _validate_pk(model, table_meta): @@ -244,7 +305,7 @@ def _pk_string(partition, clustering): _pk_string(meta_partition, meta_clustering))) -def sync_type(ks_name, type_model): +def sync_type(ks_name, type_model, connection=None): """ Inspects the type_model and creates / updates the corresponding type. @@ -262,33 +323,33 @@ def sync_type(ks_name, type_model): if not issubclass(type_model, UserType): raise CQLEngineException("Types must be derived from base UserType.") - _sync_type(ks_name, type_model) + _sync_type(ks_name, type_model, connection=connection) -def _sync_type(ks_name, type_model, omit_subtypes=None): +def _sync_type(ks_name, type_model, omit_subtypes=None, connection=None): syncd_sub_types = omit_subtypes or set() for field in type_model._fields.values(): udts = [] columns.resolve_udts(field, udts) for udt in [u for u in udts if u not in syncd_sub_types]: - _sync_type(ks_name, udt, syncd_sub_types) + _sync_type(ks_name, udt, syncd_sub_types, connection=connection) syncd_sub_types.add(udt) type_name = type_model.type_name() type_name_qualified = "%s.%s" % (ks_name, type_name) - cluster = get_cluster() + cluster = get_cluster(connection) keyspace = cluster.metadata.keyspaces[ks_name] defined_types = keyspace.user_types if type_name not in defined_types: - log.debug("sync_type creating new type %s", type_name_qualified) + log.debug(format_log_context("sync_type creating new type %s", keyspace=ks_name, connection=connection), type_name_qualified) cql = get_create_type(type_model, ks_name) - execute(cql) + execute(cql, connection=connection) cluster.refresh_user_type_metadata(ks_name, type_name) - type_model.register_for_keyspace(ks_name) + type_model.register_for_keyspace(ks_name, connection=connection) else: type_meta = defined_types[type_name] defined_fields = type_meta.field_names @@ -296,24 +357,26 @@ def _sync_type(ks_name, type_model, omit_subtypes=None): for field in type_model._fields.values(): model_fields.add(field.db_field_name) if field.db_field_name not in defined_fields: - execute("ALTER TYPE {0} ADD {1}".format(type_name_qualified, 
field.get_column_def())) + execute("ALTER TYPE {0} ADD {1}".format(type_name_qualified, field.get_column_def()), connection=connection) else: field_type = type_meta.field_types[defined_fields.index(field.db_field_name)] if field_type != field.db_type: - msg = 'Existing user type {0} has field "{1}" with a type ({2}) differing from the model user type ({3}).' \ - ' UserType should be updated.'.format(type_name_qualified, field.db_field_name, field_type, field.db_type) + msg = format_log_context('Existing user type {0} has field "{1}" with a type ({2}) differing from the model user type ({3}).' + ' UserType should be updated.', keyspace=ks_name, connection=connection) + msg = msg.format(type_name_qualified, field.db_field_name, field_type, field.db_type) warnings.warn(msg) log.warning(msg) - type_model.register_for_keyspace(ks_name) + type_model.register_for_keyspace(ks_name, connection=connection) if len(defined_fields) == len(model_fields): - log.info("Type %s did not require synchronization", type_name_qualified) + log.info(format_log_context("Type %s did not require synchronization", keyspace=ks_name, connection=connection), type_name_qualified) return db_fields_not_in_model = model_fields.symmetric_difference(defined_fields) if db_fields_not_in_model: - log.info("Type %s has fields not referenced by model: %s", type_name_qualified, db_fields_not_in_model) + msg = format_log_context("Type %s has fields not referenced by model: %s", keyspace=ks_name, connection=connection) + log.info(msg, type_name_qualified, db_fields_not_in_model) def get_create_type(type_model, keyspace): @@ -362,9 +425,9 @@ def add_column(col): return ' '.join(query_strings) -def _get_table_metadata(model): +def _get_table_metadata(model, connection=None): # returns the table as provided by the native driver for a given model - cluster = get_cluster() + cluster = get_cluster(connection) ks = model._get_keyspace() table = model._raw_column_family_name() table = cluster.metadata.keyspaces[ks].tables[table] @@ -386,19 +449,22 @@ def _options_map_from_strings(option_strings): return options -def _update_options(model): +def _update_options(model, connection=None): """Updates the table options for the given model if necessary. :param model: The model to update. + :param connection: Name of the connection to use :return: `True`, if the options were modified in Cassandra, `False` otherwise. 
:rtype: bool """ - log.debug("Checking %s for option differences", model) + ks_name = model._get_keyspace() + msg = format_log_context("Checking %s for option differences", keyspace=ks_name, connection=connection) + log.debug(msg, model) model_options = model.__options__ or {} - table_meta = _get_table_metadata(model) + table_meta = _get_table_metadata(model, connection=connection) # go to CQL string first to normalize meta from different versions existing_option_strings = set(table_meta._make_option_strings(table_meta.options)) existing_options = _options_map_from_strings(existing_option_strings) @@ -410,8 +476,9 @@ def _update_options(model): try: existing_value = existing_options[name] except KeyError: - raise KeyError("Invalid table option: '%s'; known options: %s" % (name, existing_options.keys())) - if isinstance(existing_value, six.string_types): + msg = format_log_context("Invalid table option: '%s'; known options: %s", keyspace=ks_name, connection=connection) + raise KeyError(msg % (name, existing_options.keys())) + if isinstance(existing_value, str): if value != existing_value: update_options[name] = value else: @@ -426,33 +493,49 @@ def _update_options(model): if update_options: options = ' AND '.join(metadata.TableMetadataV3._make_option_strings(update_options)) query = "ALTER TABLE {0} WITH {1}".format(model.column_family_name(), options) - execute(query) + execute(query, connection=connection) return True return False -def drop_table(model): +def drop_table(model, keyspaces=None, connections=None): """ Drops the table indicated by the model, if it exists. + If `keyspaces` is specified, the table will be dropped for all specified keyspaces. Note that the `Model.__keyspace__` is ignored in that case. + + If `connections` is specified, the table will be dropped for all specified connections. Note that the `Model.__connection__` is ignored in that case. + If not specified, it will try to get the connection from the Model. + + + **This function should be used with caution, especially in production environments. Take care to execute schema modifications in a single context (i.e. not concurrently with other clients).** *There are plans to guard schema-modifying functions with an environment-driven conditional.* """ + + context = _get_context(keyspaces, connections) + for connection, keyspace in context: + with query.ContextQuery(model, keyspace=keyspace) as m: + _drop_table(m, connection=connection) + + +def _drop_table(model, connection=None): if not _allow_schema_modification(): return - # don't try to delete non existant tables - meta = get_cluster().metadata + connection = connection or model._get_connection() + + # don't try to delete non existent tables + meta = get_cluster(connection).metadata ks_name = model._get_keyspace() raw_cf_name = model._raw_column_family_name() try: meta.keyspaces[ks_name].tables[raw_cf_name] - execute('DROP TABLE {0};'.format(model.column_family_name())) + execute('DROP TABLE {0};'.format(model.column_family_name()), connection=connection) except KeyError: pass diff --git a/cassandra/cqlengine/models.py b/cassandra/cqlengine/models.py index 4e71ccaaa7..f0f5a207ec 100644 --- a/cassandra/cqlengine/models.py +++ b/cassandra/cqlengine/models.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
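Pairing with the sync example earlier, the keyspace/connection-aware drop looks like this (same placeholder names)::

    from cassandra.cqlengine import management

    # drop the model's table from both keyspaces on the named connection
    management.drop_table(SomeModel, keyspaces=['ks_a', 'ks_b'], connections=['cluster1'])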
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -14,7 +16,6 @@ import logging import re -import six from warnings import warn from cassandra.cqlengine import CQLEngineException, ValidationError @@ -29,6 +30,17 @@ log = logging.getLogger(__name__) +def _clone_model_class(model, attrs): + new_type = type(model.__name__, (model,), attrs) + try: + new_type.__abstract__ = model.__abstract__ + new_type.__discriminator_value__ = model.__discriminator_value__ + new_type.__default_ttl__ = model.__default_ttl__ + except AttributeError: + pass + return new_type + + class ModelException(CQLEngineException): pass @@ -173,7 +185,7 @@ def __call__(self, *args, **kwargs): class IfNotExistsDescriptor(object): """ - return a query set descriptor with a if_not_exists flag specified + return a query set descriptor with an if_not_exists flag specified """ def __get__(self, instance, model): if instance: @@ -191,7 +203,7 @@ def __call__(self, *args, **kwargs): class IfExistsDescriptor(object): """ - return a query set descriptor with a if_exists flag specified + return a query set descriptor with an if_exists flag specified """ def __get__(self, instance, model): if instance: @@ -231,6 +243,25 @@ def __call__(self, *args, **kwargs): raise NotImplementedError +class UsingDescriptor(object): + """ + return a query set descriptor with a connection context specified + """ + def __get__(self, instance, model): + if instance: + # instance method + def using_setter(connection=None): + if connection: + instance._connection = connection + return instance + return using_setter + + return model.objects.using + + def __call__(self, *args, **kwargs): + raise NotImplementedError + + class ColumnQueryEvaluator(query.AbstractQueryableColumn): """ Wraps a column and allows it to be used in comparator @@ -323,6 +354,8 @@ class MultipleObjectsReturned(_MultipleObjectsReturned): if_exists = IfExistsDescriptor() + using = UsingDescriptor() + # _len is lazily created by __len__ __table_name__ = None @@ -331,10 +364,14 @@ class MultipleObjectsReturned(_MultipleObjectsReturned): __keyspace__ = None + __connection__ = None + __discriminator_value__ = None __options__ = None + __compute_routing_key__ = True + # the queryset class used for this class __queryset__ = query.ModelQuerySet __dmlquery__ = query.DMLQuery @@ -349,17 +386,24 @@ class MultipleObjectsReturned(_MultipleObjectsReturned): _table_name = None # used internally to cache a derived table name + _connection = None + def __init__(self, **values): - self._ttl = self.__default_ttl__ + self._ttl = None self._timestamp = None self._conditional = None self._batch = None self._timeout = connection.NOT_SET self._is_persisted = False + self._connection = None self._values = {} for name, column in self._columns.items(): - value = values.get(name) + # Set default values on instantiation. Thanks to this, we don't have + # to wait any longer for a call to validate() to have CQLengine set + # default columns values. 
+ column_default = column.get_default() if column.has_default else None + value = values.get(name, column_default) if value is not None or isinstance(column, columns.BaseContainerColumn): value = column.to_python(value) value_mngr = column.value_manager(self, column, value) @@ -379,6 +423,10 @@ def __str__(self): return '{0} <{1}>'.format(self.__class__.__name__, ', '.join('{0}={1}'.format(k, getattr(self, k)) for k in self._primary_keys.keys())) + @classmethod + def _routing_key_from_values(cls, pk_values, protocol_version): + return cls._key_serializer(pk_values, protocol_version) + @classmethod def _discover_polymorphic_submodels(cls): if not cls._is_polymorphic_base: @@ -437,12 +485,14 @@ def _construct_instance(cls, values): klass = cls instance = klass(**values) - instance._set_persisted() + instance._set_persisted(force=True) return instance - def _set_persisted(self): - for v in self._values.values(): + def _set_persisted(self, force=False): + # ensure we don't modify to any values not affected by the last save/update + for v in [v for v in self._values.values() if v.changed or force]: v.reset_previous_value() + v.explicit = False self._is_persisted = True def _can_update(self): @@ -518,6 +568,7 @@ def _raw_column_family_name(cls): if not cls._table_name: if cls.__table_name__: if cls.__table_name_case_sensitive__: + warn("Model __table_name_case_sensitive__ will be removed in 4.0.", PendingDeprecationWarning) cls._table_name = cls.__table_name__ else: table_name = cls.__table_name__.lower() @@ -541,6 +592,10 @@ def _raw_column_family_name(cls): return cls._table_name + def _set_column_value(self, name, value): + """Function to change a column value without changing the value manager states""" + self._values[name].value = value # internal assignement, skip the main setter + def validate(self): """ Cleans and validates the field values @@ -550,7 +605,7 @@ def validate(self): if v is None and not self._values[name].explicit and col.has_default: v = col.get_default() val = col.validate(v) - setattr(self, name, val) + self._set_column_value(name, val) # Let an instance be used like a dict of its columns keys/values def __iter__(self): @@ -560,7 +615,7 @@ def __iter__(self): def __getitem__(self, key): """ Returns column's value. """ - if not isinstance(key, six.string_types): + if not isinstance(key, str): raise TypeError if key not in self._columns.keys(): raise KeyError @@ -568,7 +623,7 @@ def __getitem__(self, key): def __setitem__(self, key, val): """ Sets a column's value. """ - if not isinstance(key, six.string_types): + if not isinstance(key, str): raise TypeError if key not in self._columns.keys(): raise KeyError @@ -608,7 +663,8 @@ def create(cls, **kwargs): """ Create an instance of this model in the database. - Takes the model column values as keyword arguments. + Takes the model column values as keyword arguments. Setting a value to + `None` is equivalent to running a CQL `DELETE` on that column. Returns the instance. """ @@ -685,7 +741,6 @@ def save(self): self._set_persisted() - self._ttl = self.__default_ttl__ self._timestamp = None return self @@ -695,23 +750,30 @@ def update(self, **values): Performs an update on the model instance. You can pass in values to set on the model for updating, or you can call without values to execute an update against any modified fields. If no fields on the model have been modified since loading, no query will be - performed. Model validation is performed normally. + performed. Model validation is performed normally. 
Setting a value to `None` is + equivalent to running a CQL `DELETE` on that column. It is possible to do a blind update, that is, to update a field without having first selected the object out of the database. See :ref:`Blind Updates ` """ - for k, v in values.items(): - col = self._columns.get(k) + for column_id, v in values.items(): + col = self._columns.get(column_id) # check for nonexistant columns if col is None: - raise ValidationError("{0}.{1} has no column named: {2}".format(self.__module__, self.__class__.__name__, k)) + raise ValidationError( + "{0}.{1} has no column named: {2}".format( + self.__module__, self.__class__.__name__, column_id)) # check for primary key update attempts if col.is_primary_key: - raise ValidationError("Cannot apply update to primary key '{0}' for {1}.{2}".format(k, self.__module__, self.__class__.__name__)) + current_value = getattr(self, column_id) + if v != current_value: + raise ValidationError( + "Cannot apply update to primary key '{0}' for {1}.{2}".format( + column_id, self.__module__, self.__class__.__name__)) - setattr(self, k, v) + setattr(self, column_id, v) # handle polymorphic models if self._is_polymorphic: @@ -732,7 +794,6 @@ def update(self, **values): self._set_persisted() - self._ttl = self.__default_ttl__ self._timestamp = None return self @@ -761,11 +822,22 @@ def _class_batch(cls, batch): def _inst_batch(self, batch): assert self._timeout is connection.NOT_SET, 'Setting both timeout and batch is not supported' + if self._connection: + raise CQLEngineException("Cannot specify a connection on model in batch mode.") self._batch = batch return self batch = hybrid_classmethod(_class_batch, _inst_batch) + @classmethod + def _class_get_connection(cls): + return cls.__connection__ + + def _inst_get_connection(self): + return self._connection or self.__connection__ + + _get_connection = hybrid_classmethod(_class_get_connection, _inst_get_connection) + class ModelMetaClass(type): @@ -788,17 +860,10 @@ def __new__(cls, name, bases, attrs): # short circuit __discriminator_value__ inheritance attrs['__discriminator_value__'] = attrs.get('__discriminator_value__') + # TODO __default__ttl__ should be removed in the next major release options = attrs.get('__options__') or {} attrs['__default_ttl__'] = options.get('default_time_to_live') - def _transform_column(col_name, col_obj): - column_dict[col_name] = col_obj - if col_obj.primary_key: - primary_keys[col_name] = col_obj - col_obj.set_column_name(col_name) - # set properties - attrs[col_name] = ColumnDescriptor(col_obj) - column_definitions = [(k, v) for k, v in attrs.items() if isinstance(v, columns.Column)] column_definitions = sorted(column_definitions, key=lambda x: x[1].position) @@ -843,6 +908,15 @@ def _get_polymorphic_base(bases): has_partition_keys = any(v.partition_key for (k, v) in column_definitions) + def _transform_column(col_name, col_obj): + column_dict[col_name] = col_obj + if col_obj.primary_key: + primary_keys[col_name] = col_obj + col_obj.set_column_name(col_name) + # set properties + attrs[col_name] = ColumnDescriptor(col_obj) + + partition_key_index = 0 # transform column definitions for k, v in column_definitions: # don't allow a column with the same name as a built-in attribute or method @@ -858,11 +932,29 @@ def _get_polymorphic_base(bases): if not has_partition_keys and v.primary_key: v.partition_key = True has_partition_keys = True + if v.partition_key: + v._partition_key_index = partition_key_index + partition_key_index += 1 + + overriding = column_dict.get(k) + if 
overriding: + v.position = overriding.position + v.partition_key = overriding.partition_key + v._partition_key_index = overriding._partition_key_index _transform_column(k, v) partition_keys = OrderedDict(k for k in primary_keys.items() if k[1].partition_key) clustering_keys = OrderedDict(k for k in primary_keys.items() if not k[1].partition_key) + if attrs.get('__compute_routing_key__', True): + key_cols = [c for c in partition_keys.values()] + partition_key_index = dict((col.db_field_name, col._partition_key_index) for col in key_cols) + key_cql_types = [c.cql_type for c in key_cols] + key_serializer = staticmethod(lambda parts, proto_version: [t.to_binary(p, proto_version) for t, p in zip(key_cql_types, parts)]) + else: + partition_key_index = {} + key_serializer = staticmethod(lambda parts, proto_version: None) + # setup partition key shortcut if len(partition_keys) == 0: if not is_abstract: @@ -906,6 +998,8 @@ def _get_polymorphic_base(bases): attrs['_dynamic_columns'] = {} attrs['_partition_keys'] = partition_keys + attrs['_partition_key_index'] = partition_key_index + attrs['_key_serializer'] = key_serializer attrs['_clustering_keys'] = clustering_keys attrs['_has_counter'] = len(counter_columns) > 0 @@ -949,8 +1043,7 @@ def _get_polymorphic_base(bases): return klass -@six.add_metaclass(ModelMetaClass) -class Model(BaseModel): +class Model(BaseModel, metaclass=ModelMetaClass): __abstract__ = True """ *Optional.* Indicates that this model is only intended to be used as a base class for other models. @@ -972,6 +1065,11 @@ class Model(BaseModel): Sets the name of the keyspace used by this model. """ + __connection__ = None + """ + Sets the name of the default connection used by this model. + """ + __options__ = None """ *Optional* Table options applied with this model @@ -983,3 +1081,8 @@ class Model(BaseModel): """ *Optional* Specifies a value for the discriminator column when using model inheritance. """ + + __compute_routing_key__ = True + """ + *Optional* Setting False disables computing the routing key for TokenAwareRouting + """ diff --git a/cassandra/cqlengine/named.py b/cassandra/cqlengine/named.py index 90a0d8fdf2..219155818c 100644 --- a/cassandra/cqlengine/named.py +++ b/cassandra/cqlengine/named.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
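To illustrate the per-model connection and the on-the-fly ``using()`` override introduced above, a sketch (model, keyspace, and connection names are placeholders)::

    from cassandra.cqlengine import columns
    from cassandra.cqlengine.models import Model

    class Automobile(Model):
        __keyspace__ = 'test_ks'
        __connection__ = 'cluster1'     # default connection for this model
        manufacturer = columns.Text(primary_key=True)
        year = columns.Integer(primary_key=True)

    # reads and writes go to 'cluster1' unless overridden per query/instance
    Automobile.objects.using(connection='cluster2').create(manufacturer='honda', year=2008)
    q = Automobile.objects.using(keyspace='other_ks', connection='cluster2').all()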
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -17,6 +19,7 @@ from cassandra.cqlengine import CQLEngineException from cassandra.cqlengine.columns import Column from cassandra.cqlengine.connection import get_cluster +from cassandra.cqlengine.models import UsingDescriptor, BaseModel from cassandra.cqlengine.query import AbstractQueryableColumn, SimpleQuerySet from cassandra.cqlengine.query import DoesNotExist as _DoesNotExist from cassandra.cqlengine.query import MultipleObjectsReturned as _MultipleObjectsReturned @@ -84,6 +87,15 @@ class NamedTable(object): __partition_keys = None + _partition_key_index = None + + __connection__ = None + _connection = None + + using = UsingDescriptor() + + _get_connection = BaseModel._get_connection + class DoesNotExist(_DoesNotExist): pass @@ -93,6 +105,7 @@ class MultipleObjectsReturned(_MultipleObjectsReturned): def __init__(self, keyspace, name): self.keyspace = keyspace self.name = name + self._connection = None @property def _partition_keys(self): @@ -102,7 +115,7 @@ def _partition_keys(self): def _get_partition_keys(self): try: - table_meta = get_cluster().metadata.keyspaces[self.keyspace].tables[self.name] + table_meta = get_cluster(self._get_connection()).metadata.keyspaces[self.keyspace].tables[self.name] self.__partition_keys = OrderedDict((pk.name, Column(primary_key=True, partition_key=True, db_field=pk.name)) for pk in table_meta.partition_key) except Exception as e: raise CQLEngineException("Failed inspecting partition keys for {0}." diff --git a/cassandra/cqlengine/operators.py b/cassandra/cqlengine/operators.py index 0aa29d94ae..a9e7db2545 100644 --- a/cassandra/cqlengine/operators.py +++ b/cassandra/cqlengine/operators.py @@ -1,17 +1,18 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import six from cassandra.cqlengine import UnicodeMixin @@ -44,8 +45,7 @@ def __init__(cls, name, bases, dct): super(OpMapMeta, cls).__init__(name, bases, dct) -@six.add_metaclass(OpMapMeta) -class BaseWhereOperator(BaseQueryOperator): +class BaseWhereOperator(BaseQueryOperator, metaclass=OpMapMeta): """ base operator used for where clauses """ @classmethod def get_operator(cls, symbol): @@ -60,6 +60,11 @@ class EqualsOperator(BaseWhereOperator): cql_symbol = '=' +class NotEqualsOperator(BaseWhereOperator): + symbol = 'NE' + cql_symbol = '!=' + + class InOperator(EqualsOperator): symbol = 'IN' cql_symbol = 'IN' @@ -88,3 +93,13 @@ class LessThanOrEqualOperator(BaseWhereOperator): class ContainsOperator(EqualsOperator): symbol = "CONTAINS" cql_symbol = 'CONTAINS' + + +class LikeOperator(EqualsOperator): + symbol = "LIKE" + cql_symbol = 'LIKE' + + +class IsNotNullOperator(EqualsOperator): + symbol = "IS NOT NULL" + cql_symbol = 'IS NOT NULL' diff --git a/cassandra/cqlengine/query.py b/cassandra/cqlengine/query.py index f9e1a75d42..329bc7fade 100644 --- a/cassandra/cqlengine/query.py +++ b/cassandra/cqlengine/query.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -16,20 +18,18 @@ from datetime import datetime, timedelta from functools import partial import time -import six from warnings import warn +from cassandra.query import SimpleStatement, BatchType as CBatchType, BatchStatement from cassandra.cqlengine import columns, CQLEngineException, ValidationError, UnicodeMixin -from cassandra.cqlengine import connection +from cassandra.cqlengine import connection as conn from cassandra.cqlengine.functions import Token, BaseQueryFunction, QueryValue from cassandra.cqlengine.operators import (InOperator, EqualsOperator, GreaterThanOperator, GreaterThanOrEqualOperator, LessThanOperator, LessThanOrEqualOperator, ContainsOperator, BaseWhereOperator) from cassandra.cqlengine.statements import (WhereClause, SelectStatement, DeleteStatement, - UpdateStatement, AssignmentClause, InsertStatement, - BaseCQLStatement, MapUpdateClause, MapDeleteClause, - ListUpdateClause, SetUpdateClause, CounterUpdateClause, - ConditionalClause) + UpdateStatement, InsertStatement, + BaseCQLStatement, MapDeleteClause, ConditionalClause) class QueryException(CQLEngineException): @@ -39,6 +39,7 @@ class QueryException(CQLEngineException): class IfNotExistsWithCounterColumn(CQLEngineException): pass + class IfExistsWithCounterColumn(CQLEngineException): pass @@ -67,14 +68,16 @@ class MultipleObjectsReturned(QueryException): def check_applied(result): """ - Raises LWTException if it looks like a failed LWT request. + Raises LWTException if it looks like a failed LWT request. 
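For instance, a failed lightweight transaction surfaces through the mapper roughly like this (a sketch; ``Automobile`` is the hypothetical model from the earlier example)::

    from cassandra.cqlengine.query import LWTException

    try:
        Automobile.if_not_exists().create(manufacturer='honda', year=2008)
    except LWTException as e:
        # e.existing carries the row that prevented the conditional insert
        print(e.existing)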
A LWTException + won't be raised in the special case in which there are several failed LWT + in a :class:`~cqlengine.query.BatchQuery`. """ try: applied = result.was_applied except Exception: applied = True # result was not LWT form if not applied: - raise LWTException(result[0]) + raise LWTException(result.one()) class AbstractQueryableColumn(UnicodeMixin): @@ -101,29 +104,29 @@ def in_(self, item): used where you'd typically want to use python's `in` operator """ - return WhereClause(six.text_type(self), InOperator(), item) + return WhereClause(str(self), InOperator(), item) def contains_(self, item): """ Returns a CONTAINS operator """ - return WhereClause(six.text_type(self), ContainsOperator(), item) + return WhereClause(str(self), ContainsOperator(), item) def __eq__(self, other): - return WhereClause(six.text_type(self), EqualsOperator(), self._to_database(other)) + return WhereClause(str(self), EqualsOperator(), self._to_database(other)) def __gt__(self, other): - return WhereClause(six.text_type(self), GreaterThanOperator(), self._to_database(other)) + return WhereClause(str(self), GreaterThanOperator(), self._to_database(other)) def __ge__(self, other): - return WhereClause(six.text_type(self), GreaterThanOrEqualOperator(), self._to_database(other)) + return WhereClause(str(self), GreaterThanOrEqualOperator(), self._to_database(other)) def __lt__(self, other): - return WhereClause(six.text_type(self), LessThanOperator(), self._to_database(other)) + return WhereClause(str(self), LessThanOperator(), self._to_database(other)) def __le__(self, other): - return WhereClause(six.text_type(self), LessThanOrEqualOperator(), self._to_database(other)) + return WhereClause(str(self), LessThanOrEqualOperator(), self._to_database(other)) class BatchType(object): @@ -135,17 +138,23 @@ class BatchQuery(object): """ Handles the batching of queries - http://www.datastax.com/docs/1.2/cql_cli/cql/BATCH + http://docs.datastax.com/en/cql/3.0/cql/cql_reference/batch_r.html + + See :doc:`/cqlengine/batches` for more details. """ warn_multiple_exec = True _consistency = None + _connection = None + _connection_explicit = False + + def __init__(self, batch_type=None, timestamp=None, consistency=None, execute_on_exception=False, - timeout=connection.NOT_SET): + timeout=conn.NOT_SET, connection=None): """ :param batch_type: (optional) One of batch type values available through BatchType enum - :type batch_type: str or None + :type batch_type: BatchType, str or None :param timestamp: (optional) A datetime or timedelta object with desired timestamp to be applied to the batch conditional. 
:type timestamp: datetime or timedelta or None @@ -159,6 +168,7 @@ def __init__(self, batch_type=None, timestamp=None, consistency=None, execute_on :param timeout: (optional) Timeout for the entire batch (in seconds), if not specified fallback to default session timeout :type timeout: float or None + :param str connection: Connection name to use for the batch execution """ self.queries = [] self.batch_type = batch_type @@ -171,6 +181,9 @@ def __init__(self, batch_type=None, timestamp=None, consistency=None, execute_on self._callbacks = [] self._executed = False self._context_entered = False + self._connection = connection + if connection: + self._connection_explicit = True def add_query(self, query): if not isinstance(query, BaseCQLStatement): @@ -194,8 +207,8 @@ def add_callback(self, fn, *args, **kwargs): :param fn: Callable object :type fn: callable - :param *args: Positional arguments to be passed to the callback at the time of execution - :param **kwargs: Named arguments to be passed to the callback at the time of execution + :param args: Positional arguments to be passed to the callback at the time of execution + :param kwargs: Named arguments to be passed to the callback at the time of execution """ if not callable(fn): raise ValueError("Value for argument 'fn' is {0} and is not a callable object.".format(type(fn))) @@ -215,10 +228,11 @@ def execute(self): self._execute_callbacks() return - opener = 'BEGIN ' + (self.batch_type + ' ' if self.batch_type else '') + ' BATCH' + batch_type = None if self.batch_type is CBatchType.LOGGED else self.batch_type + opener = 'BEGIN ' + (str(batch_type) + ' ' if batch_type else '') + ' BATCH' if self.timestamp: - if isinstance(self.timestamp, six.integer_types): + if isinstance(self.timestamp, int): ts = self.timestamp elif isinstance(self.timestamp, (datetime, timedelta)): ts = self.timestamp @@ -242,7 +256,7 @@ def execute(self): query_list.append('APPLY BATCH;') - tmp = connection.execute('\n'.join(query_list), parameters, self._consistency, self._timeout) + tmp = conn.execute('\n'.join(query_list), parameters, self._consistency, self._timeout, connection=self._connection) check_applied(tmp) self.queries = [] @@ -259,6 +273,71 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.execute() +class ContextQuery(object): + """ + A Context manager to allow a Model to switch context easily. Presently, the context only + specifies a keyspace for model IO. + + :param args: One or more models. A model should be a class type, not an instance. + :param kwargs: (optional) Context parameters: can be *keyspace* or *connection* + + For example: + + .. 
code-block:: python + + with ContextQuery(Automobile, keyspace='test2') as A: + A.objects.create(manufacturer='honda', year=2008, model='civic') + print(len(A.objects.all())) # 1 result + + with ContextQuery(Automobile, keyspace='test4') as A: + print(len(A.objects.all())) # 0 result + + # Multiple models + with ContextQuery(Automobile, Automobile2, connection='cluster2') as (A, A2): + print(len(A.objects.all())) + print(len(A2.objects.all())) + + """ + + def __init__(self, *args, **kwargs): + from cassandra.cqlengine import models + + self.models = [] + + if len(args) < 1: + raise ValueError("No model provided.") + + keyspace = kwargs.pop('keyspace', None) + connection = kwargs.pop('connection', None) + + if kwargs: + raise ValueError("Unknown keyword argument(s): {0}".format( + ','.join(kwargs.keys()))) + + for model in args: + try: + issubclass(model, models.Model) + except TypeError: + raise ValueError("Models must be derived from base Model.") + + m = models._clone_model_class(model, {}) + + if keyspace: + m.__keyspace__ = keyspace + if connection: + m.__connection__ = connection + + self.models.append(m) + + def __enter__(self): + if len(self.models) > 1: + return tuple(self.models) + return self.models[0] + + def __exit__(self, exc_type, exc_val, exc_tb): + return + + class AbstractQuerySet(object): def __init__(self, model): @@ -280,9 +359,15 @@ def __init__(self, model): # because explicit is better than implicit self._limit = 10000 - # see the defer and only methods + # We store the fields for which we use the Equal operator + # in a query, so we don't select it from the DB. _defer_fields + # will contain the names of the fields in the DB, not the names + # of the variables used by the mapper self._defer_fields = set() self._deferred_values = {} + + # This variable will hold the names in the database of the fields + # for which we want to query self._only_fields = [] self._values_list = False @@ -299,29 +384,31 @@ def __init__(self, model): self._count = None self._batch = None - self._ttl = getattr(model, '__default_ttl__', None) + self._ttl = None self._consistency = None self._timestamp = None self._if_not_exists = False - self._timeout = connection.NOT_SET + self._timeout = conn.NOT_SET self._if_exists = False self._fetch_size = None + self._connection = None @property def column_family_name(self): return self.model.column_family_name() - def _execute(self, q): + def _execute(self, statement): if self._batch: - return self._batch.add_query(q) + return self._batch.add_query(statement) else: - result = connection.execute(q, consistency_level=self._consistency, timeout=self._timeout) + connection = self._connection or self.model._get_connection() + result = _execute_statement(self.model, statement, self._consistency, self._timeout, connection=connection) if self._if_not_exists or self._if_exists or self._conditional: check_applied(result) return result def __unicode__(self): - return six.text_type(self._select_query()) + return str(self._select_query()) def __str__(self): return str(self.__unicode__()) @@ -332,7 +419,7 @@ def __call__(self, *args, **kwargs): def __deepcopy__(self, memo): clone = self.__class__(self.model) for k, v in self.__dict__.items(): - if k in ['_con', '_cur', '_result_cache', '_result_idx', '_result_generator']: # don't clone these + if k in ['_con', '_cur', '_result_cache', '_result_idx', '_result_generator', '_construct_result']: # don't clone these, which are per-request-execution clone.__dict__[k] = None elif k == '_batch': # we need to keep the same 
batch instance across @@ -449,6 +536,10 @@ def __getitem__(self, s): if isinstance(s, slice): start = s.start if s.start else 0 + if start < 0 or (s.stop is not None and s.stop < 0): + warn("ModelQuerySet slicing with negative indices support will be removed in 4.0.", + DeprecationWarning) + # calculate the amount of results that need to be loaded end = s.stop if start < 0 or s.stop is None or s.stop < 0: @@ -466,6 +557,10 @@ def __getitem__(self, s): except (ValueError, TypeError): raise TypeError('QuerySet indices must be integers') + if s < 0: + warn("ModelQuerySet indexing with negative indices support will be removed in 4.0.", + DeprecationWarning) + # Using negative indexing is costly since we have to execute a count() if s < 0: num_results = self.count() @@ -499,6 +594,9 @@ def batch(self, batch_obj): Note: running a select query with a batch object will raise an exception """ + if self._connection: + raise CQLEngineException("Cannot specify the connection on model in batch mode.") + if batch_obj is not None and not isinstance(batch_obj, BatchQuery): raise CQLEngineException('batch_obj must be a BatchQuery instance or None') clone = copy.deepcopy(self) @@ -507,7 +605,7 @@ def batch(self, batch_obj): def first(self): try: - return six.next(iter(self)) + return next(iter(self)) except StopIteration: return None @@ -545,7 +643,7 @@ def _parse_filter_arg(self, arg): if len(statement) == 1: return arg, None elif len(statement) == 2: - return statement[0], statement[1] + return (statement[0], statement[1]) if arg != 'pk__token' else (arg, None) else: raise QueryException("Can't parse '{0}'".format(arg)) @@ -560,10 +658,11 @@ def iff(self, *args, **kwargs): raise QueryException('{0} is not a valid query operator'.format(operator)) clone._conditional.append(operator) - for col_name, val in kwargs.items(): + for arg, val in kwargs.items(): if isinstance(val, Token): raise QueryException("Token() values are not valid in conditionals") + col_name, col_op = self._parse_filter_arg(arg) try: column = self.model._get_column(col_name) except KeyError: @@ -574,7 +673,9 @@ def iff(self, *args, **kwargs): else: query_val = column.to_database(val) - clone._conditional.append(ConditionalClause(col_name, query_val)) + operator_class = BaseWhereOperator.get_operator(col_op or 'EQ') + operator = operator_class() + clone._conditional.append(WhereClause(column.db_field_name, operator, query_val)) return clone @@ -636,7 +737,7 @@ def filter(self, *args, **kwargs): else: query_val = column.to_database(val) if not col_op: # only equal values should be deferred - clone._defer_fields.add(col_name) + clone._defer_fields.add(column.db_field_name) clone._deferred_values[column.db_field_name] = val # map by db field name for substitution in results clone._where.append(WhereClause(column.db_field_name, operator, query_val, quote_field=quote_field)) @@ -709,11 +810,11 @@ class Comment(Model): print("Normal") for comment in Comment.objects(photo_id=u): - print comment.comment_id + print(comment.comment_id) print("Reversed") for comment in Comment.objects(photo_id=u).order_by("-comment_id"): - print comment.comment_id + print(comment.comment_id) """ if len(colnames) == 0: clone = copy.deepcopy(self) @@ -741,7 +842,7 @@ def count(self): query = self._select_query() query.count = True result = self._execute(query) - count_row = result[0].popitem() + count_row = result.one().popitem() self._count = count_row[1] return self._count @@ -801,7 +902,7 @@ def limit(self, v): if v is None: v = 0 - if not isinstance(v, 
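
Since iff() now routes its keyword arguments through _parse_filter_arg, conditional (LWT) updates accept the same operator suffixes as filter(). A sketch, assuming a Row model with an integer count column (hypothetical) and a live connection:

.. code-block:: python

    # IF count = 4
    Row.objects(row_id=5).iff(count=4).update(count=5)

    # IF count > 3 -- the __gt suffix is resolved via BaseWhereOperator.get_operator
    Row.objects(row_id=5).iff(count__gt=3).update(count=5)
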
six.integer_types): + if not isinstance(v, int): raise TypeError if v == self._limit: return self @@ -825,7 +926,7 @@ def fetch_size(self, v): print(user) """ - if not isinstance(v, six.integer_types): + if not isinstance(v, int): raise TypeError if v == self._fetch_size: return self @@ -839,7 +940,7 @@ def fetch_size(self, v): def allow_filtering(self): """ - Enables the (usually) unwise practive of querying on a clustering key without also defining a partition key + Enables the (usually) unwise practice of querying on a clustering key without also defining a partition key """ clone = copy.deepcopy(self) clone._allow_filtering = True @@ -858,6 +959,8 @@ def _only_or_defer(self, action, fields): "Can't resolve fields {0} in {1}".format( ', '.join(missing_fields), self.model.__name__)) + fields = [self.model._columns[field].db_field_name for field in fields] + if action == 'defer': clone._defer_fields.update(fields) elif action == 'only': @@ -883,6 +986,7 @@ def create(self, **kwargs): .if_not_exists(self._if_not_exists) \ .timestamp(self._timestamp) \ .if_exists(self._if_exists) \ + .using(connection=self._connection) \ .save() def delete(self): @@ -920,6 +1024,24 @@ def timeout(self, timeout): clone._timeout = timeout return clone + def using(self, keyspace=None, connection=None): + """ + Change the context on-the-fly of the Model class (keyspace, connection) + """ + + if connection and self._batch: + raise CQLEngineException("Cannot specify a connection on model in batch mode.") + + clone = copy.deepcopy(self) + if keyspace: + from cassandra.cqlengine.models import _clone_model_class + clone.model = _clone_model_class(self.model, {'__keyspace__': keyspace}) + + if connection: + clone._connection = connection + + return clone + class ResultObject(dict): """ @@ -950,27 +1072,48 @@ class ModelQuerySet(AbstractQuerySet): """ def _validate_select_where(self): """ Checks that a filterset will not create invalid select statement """ - # check that there's either a =, a IN or a CONTAINS (collection) relationship with a primary key or indexed field - equal_ops = [self.model._get_column_by_db_name(w.field) for w in self._where if isinstance(w.operator, EqualsOperator)] + # check that there's either a =, a IN or a CONTAINS (collection) + # relationship with a primary key or indexed field. We also allow + # custom indexes to be queried with any operator (a difference + # between a secondary index) + equal_ops = [self.model._get_column_by_db_name(w.field) \ + for w in self._where if not isinstance(w.value, Token) + and (isinstance(w.operator, EqualsOperator) + or self.model._get_column_by_db_name(w.field).custom_index)] token_comparison = any([w for w in self._where if isinstance(w.value, Token)]) - if not any(w.primary_key or w.index for w in equal_ops) and not token_comparison and not self._allow_filtering: - raise QueryException(('Where clauses require either =, a IN or a CONTAINS (collection) ' - 'comparison with either a primary key or indexed field')) + if not any(w.primary_key or w.has_index for w in equal_ops) and not token_comparison and not self._allow_filtering: + raise QueryException( + ('Where clauses require either =, a IN or a CONTAINS ' + '(collection) comparison with either a primary key or ' + 'indexed field. 
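
The new using() method swaps the keyspace or connection for a single queryset without redefining the model. A sketch, assuming the Automobile model above and a registered 'cluster2' connection:

.. code-block:: python

    # Read the same model out of another keyspace...
    Automobile.objects.using(keyspace='test2').filter(manufacturer='honda').count()

    # ...or route a write through another registered connection.
    Automobile.objects.using(connection='cluster2').create(
        manufacturer='honda', year=2008, model='civic')
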
You might want to consider setting ' + 'custom_index on fields that you manage index outside ' + 'cqlengine.')) if not self._allow_filtering: # if the query is not on an indexed field - if not any(w.index for w in equal_ops): + if not any(w.has_index for w in equal_ops): if not any([w.partition_key for w in equal_ops]) and not token_comparison: - raise QueryException('Filtering on a clustering key without a partition key is not allowed unless allow_filtering() is called on the querset') + raise QueryException( + ('Filtering on a clustering key without a partition ' + 'key is not allowed unless allow_filtering() is ' + 'called on the queryset. You might want to consider ' + 'setting custom_index on fields that you manage ' + 'index outside cqlengine.')) def _select_fields(self): if self._defer_fields or self._only_fields: - fields = self.model._columns.keys() + fields = [columns.db_field_name for columns in self.model._columns.values()] if self._defer_fields: fields = [f for f in fields if f not in self._defer_fields] - elif self._only_fields: - fields = self._only_fields - return [self.model._columns[f].db_field_name for f in fields] + # select the partition keys if all model fields are set defer + if not fields: + fields = [columns.db_field_name for columns in self.model._partition_keys.values()] + if self._only_fields: + fields = [f for f in fields if f in self._only_fields] + if not fields: + raise QueryException('No fields in select query. Only fields: "{0}", defer fields: "{1}"'.format( + ','.join(self._only_fields), ','.join(self._defer_fields))) + return fields return super(ModelQuerySet, self)._select_fields() def _get_result_constructor(self): @@ -1143,11 +1286,15 @@ class Row(Model): # add items to a map Row.objects(row_id=5).update(map_column__update={1: 2, 3: 4}) + + # remove items from a map + Row.objects(row_id=5).update(map_column__remove={1, 2}) """ if not values: return nulled_columns = set() + updated_columns = set() us = UpdateStatement(self.column_family_name, where=self._where, ttl=self._ttl, timestamp=self._timestamp, conditionals=self._conditional, if_exists=self._if_exists) for name, val in values.items(): @@ -1160,21 +1307,30 @@ class Row(Model): if col.is_primary_key: raise ValidationError("Cannot apply update to primary key '{0}' for {1}.{2}".format(col_name, self.__module__, self.model.__name__)) - # we should not provide default values in this use case. - val = col.validate(val) + if col_op == 'remove' and isinstance(col, columns.Map): + if not isinstance(val, set): + raise ValidationError( + "Cannot apply update operation '{0}' on column '{1}' with value '{2}'. A set is required.".format(col_op, col_name, val)) + val = {v: None for v in val} + else: + # we should not provide default values in this use case. 
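
The remove branch above only accepts a set of keys and rewrites it as a map subtraction; any other value raises ValidationError. A sketch, assuming the Row model with a Map column named map_column from the docstring:

.. code-block:: python

    # add or overwrite entries
    Row.objects(row_id=5).update(map_column__update={1: 2, 3: 4})

    # discard entries by key -- the value must be a set
    Row.objects(row_id=5).update(map_column__remove={1, 3})
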
+ val = col.validate(val) if val is None: nulled_columns.add(col_name) continue us.add_update(col, val, operation=col_op) + updated_columns.add(col_name) if us.assignments: self._execute(us) if nulled_columns: + delete_conditional = [condition for condition in self._conditional + if condition.field not in updated_columns] if self._conditional else None ds = DeleteStatement(self.column_family_name, fields=nulled_columns, - where=self._where, conditionals=self._conditional, if_exists=self._if_exists) + where=self._where, conditionals=delete_conditional, if_exists=self._if_exists) self._execute(ds) @@ -1193,7 +1349,7 @@ class DMLQuery(object): _if_exists = False def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None, - if_not_exists=False, conditional=None, timeout=connection.NOT_SET, if_exists=False): + if_not_exists=False, conditional=None, timeout=conn.NOT_SET, if_exists=False): self.model = model self.column_family_name = self.model.column_family_name() self.instance = instance @@ -1206,14 +1362,22 @@ def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, self._conditional = conditional self._timeout = timeout - def _execute(self, q): + def _execute(self, statement): + connection = self.instance._get_connection() if self.instance else self.model._get_connection() if self._batch: - return self._batch.add_query(q) + if self._batch._connection: + if not self._batch._connection_explicit and connection and \ + connection != self._batch._connection: + raise CQLEngineException('BatchQuery queries must be executed on the same connection') + else: + # set the BatchQuery connection from the model + self._batch._connection = connection + return self._batch.add_query(statement) else: - tmp = connection.execute(q, consistency_level=self._consistency, timeout=self._timeout) + results = _execute_statement(self.model, statement, self._consistency, self._timeout, connection=connection) if self._if_not_exists or self._if_exists or self._conditional: - check_applied(tmp) - return tmp + check_applied(results) + return results def batch(self, batch_obj): if batch_obj is not None and not isinstance(batch_obj, BatchQuery): @@ -1221,30 +1385,30 @@ def batch(self, batch_obj): self._batch = batch_obj return self - def _delete_null_columns(self): + def _delete_null_columns(self, conditionals=None): """ executes a delete query to remove columns that have changed to null """ - ds = DeleteStatement(self.column_family_name, conditionals=self._conditional, if_exists=self._if_exists) + ds = DeleteStatement(self.column_family_name, conditionals=conditionals, if_exists=self._if_exists) deleted_fields = False + static_only = True for _, v in self.instance._values.items(): col = v.column if v.deleted: ds.add_field(col.db_field_name) deleted_fields = True + static_only &= col.static elif isinstance(col, columns.Map): uc = MapDeleteClause(col.db_field_name, v.value, v.previous_value) if uc.get_context_size() > 0: ds.add_field(uc) deleted_fields = True + static_only |= col.static if deleted_fields: - for name, col in self.model._primary_keys.items(): - ds.add_where_clause(WhereClause( - col.db_field_name, - EqualsOperator(), - col.to_database(getattr(self.instance, name)) - )) + keys = self.model._partition_keys if static_only else self.model._primary_keys + for name, col in keys.items(): + ds.add_where(col, EqualsOperator(), getattr(self.instance, name)) self._execute(ds) def update(self): @@ -1255,7 +1419,7 @@ def update(self): prior to calling this. 
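
The _execute logic above pins a BatchQuery to a single connection: the first model queued sets it implicitly, and mixing models bound to different connections raises CQLEngineException unless the batch was created with an explicit connection. A sketch, assuming a Row model with row_id and value columns (hypothetical):

.. code-block:: python

    from cassandra.cqlengine.query import BatchQuery

    with BatchQuery() as b:
        # both statements run on the connection implied by Row
        Row.batch(b).create(row_id=1, value=10)
        Row.batch(b).create(row_id=2, value=20)
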
""" if self.instance is None: - raise CQLEngineException("DML Query intance attribute is None") + raise CQLEngineException("DML Query instance attribute is None") assert type(self.instance) == self.model null_clustering_key = False if len(self.instance._clustering_keys) == 0 else True static_changed_only = True @@ -1263,9 +1427,11 @@ def update(self): conditionals=self._conditional, if_exists=self._if_exists) for name, col in self.instance._clustering_keys.items(): null_clustering_key = null_clustering_key and col._val_is_null(getattr(self.instance, name, None)) + + updated_columns = set() # get defined fields and their column names for name, col in self.model._columns.items(): - # if clustering key is null, don't include non static columns + # if clustering key is null, don't include non-static columns if null_clustering_key and not col.static and not col.partition_key: continue if not col.is_primary_key: @@ -1280,21 +1446,21 @@ def update(self): static_changed_only = static_changed_only and col.static statement.add_update(col, val, previous=val_mgr.previous_value) + updated_columns.add(col.db_field_name) if statement.assignments: for name, col in self.model._primary_keys.items(): - # only include clustering key if clustering key is not null, and non static columns are changed to avoid cql error + # only include clustering key if clustering key is not null, and non-static columns are changed to avoid cql error if (null_clustering_key or static_changed_only) and (not col.partition_key): continue - statement.add_where_clause(WhereClause( - col.db_field_name, - EqualsOperator(), - col.to_database(getattr(self.instance, name)) - )) + statement.add_where(col, EqualsOperator(), getattr(self.instance, name)) self._execute(statement) if not null_clustering_key: - self._delete_null_columns() + # remove conditions on fields that have been updated + delete_conditionals = [condition for condition in self._conditional + if condition.field not in updated_columns] if self._conditional else None + self._delete_null_columns(delete_conditionals) def save(self): """ @@ -1304,13 +1470,14 @@ def save(self): prior to calling this. """ if self.instance is None: - raise CQLEngineException("DML Query intance attribute is None") + raise CQLEngineException("DML Query instance attribute is None") assert type(self.instance) == self.model nulled_fields = set() if self.instance._has_counter or self.instance._can_update(): if self.instance._has_counter: - warn("'create' and 'save' actions on Counters are deprecated. A future version will disallow this. Use the 'update' mechanism instead.") + warn("'create' and 'save' actions on Counters are deprecated. It will be disallowed in 4.0. 
" + "Use the 'update' mechanism instead.", DeprecationWarning) return self.update() else: insert = InsertStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp, if_not_exists=self._if_not_exists) @@ -1325,10 +1492,10 @@ def save(self): if self.instance._values[name].changed: nulled_fields.add(col.db_field_name) continue - insert.add_assignment_clause(AssignmentClause( - col.db_field_name, - col.to_database(getattr(self.instance, name, None)) - )) + if col.has_default and not self.instance._values[name].changed: + # Ensure default columns included in a save() are marked as explicit, to get them *persisted* properly + self.instance._values[name].explicit = True + insert.add_assignment(col, getattr(self.instance, name, None)) # skip query execution if it's empty # caused by pointless update queries @@ -1345,12 +1512,21 @@ def delete(self): ds = DeleteStatement(self.column_family_name, timestamp=self._timestamp, conditionals=self._conditional, if_exists=self._if_exists) for name, col in self.model._primary_keys.items(): - if (not col.partition_key) and (getattr(self.instance, name) is None): + val = getattr(self.instance, name) + if val is None and not col.partition_key: continue - - ds.add_where_clause(WhereClause( - col.db_field_name, - EqualsOperator(), - col.to_database(getattr(self.instance, name)) - )) + ds.add_where(col, EqualsOperator(), val) self._execute(ds) + + +def _execute_statement(model, statement, consistency_level, timeout, connection=None): + params = statement.get_context() + s = SimpleStatement(str(statement), consistency_level=consistency_level, fetch_size=statement.fetch_size) + if model._partition_key_index: + key_values = statement.partition_key_values(model._partition_key_index) + if not any(v is None for v in key_values): + parts = model._routing_key_from_values(key_values, conn.get_cluster(connection).protocol_version) + s.routing_key = parts + s.keyspace = model._get_keyspace() + connection = connection or model._get_connection() + return conn.execute(s, params, timeout=timeout, connection=connection) diff --git a/cassandra/cqlengine/statements.py b/cassandra/cqlengine/statements.py index f5f626a49b..b20b07ef56 100644 --- a/cassandra/cqlengine/statements.py +++ b/cassandra/cqlengine/statements.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -14,13 +16,12 @@ from datetime import datetime, timedelta import time -import six from cassandra.query import FETCH_SIZE_UNSET from cassandra.cqlengine import columns from cassandra.cqlengine import UnicodeMixin from cassandra.cqlengine.functions import QueryValue -from cassandra.cqlengine.operators import BaseWhereOperator, InOperator +from cassandra.cqlengine.operators import BaseWhereOperator, InOperator, EqualsOperator, IsNotNullOperator class StatementException(Exception): @@ -34,9 +35,7 @@ def __init__(self, value): def __unicode__(self): from cassandra.encoder import cql_quote - if isinstance(self.value, bool): - return 'true' if self.value else 'false' - elif isinstance(self.value, (list, tuple)): + if isinstance(self.value, (list, tuple)): return '[' + ', '.join([cql_quote(v) for v in self.value]) + ']' elif isinstance(self.value, dict): return '{' + ', '.join([cql_quote(k) + ':' + cql_quote(v) for k, v in self.value.items()]) + '}' @@ -115,7 +114,7 @@ def __init__(self, field, operator, value, quote_field=True): def __unicode__(self): field = ('"{0}"' if self.quote_field else '{0}').format(self.field) - return u'{0} {1} {2}'.format(field, self.operator, six.text_type(self.query_value)) + return u'{0} {1} {2}'.format(field, self.operator, str(self.query_value)) def __hash__(self): return super(WhereClause, self).__hash__() ^ hash(self.operator) @@ -139,6 +138,24 @@ def update_context(self, ctx): self.query_value.update_context(ctx) +class IsNotNullClause(WhereClause): + def __init__(self, field): + super(IsNotNullClause, self).__init__(field, IsNotNullOperator(), '') + + def __unicode__(self): + field = ('"{0}"' if self.quote_field else '{0}').format(self.field) + return u'{0} {1}'.format(field, self.operator) + + def update_context(self, ctx): + pass + + def get_context_size(self): + return 0 + +# alias for convenience +IsNotNull = IsNotNullClause + + class AssignmentClause(BaseClause): """ a single variable st statement """ @@ -169,8 +186,7 @@ def __init__(cls, name, bases, dct): super(ContainerUpdateTypeMapMeta, cls).__init__(name, bases, dct) -@six.add_metaclass(ContainerUpdateTypeMapMeta) -class ContainerUpdateClause(AssignmentClause): +class ContainerUpdateClause(AssignmentClause, metaclass=ContainerUpdateTypeMapMeta): def __init__(self, field, value, operation=None, previous=None): super(ContainerUpdateClause, self).__init__(field, value) @@ -360,10 +376,13 @@ class MapUpdateClause(ContainerUpdateClause): col_type = columns.Map _updates = None + _removals = None def _analyze(self): if self._operation == "update": self._updates = self.value.keys() + elif self._operation == "remove": + self._removals = {v for v in self.value.keys()} else: if self.previous is None: self._updates = sorted([k for k, v in self.value.items()]) @@ -374,12 +393,14 @@ def _analyze(self): def get_context_size(self): if self.is_assignment: return 1 - return len(self._updates or []) * 2 + return int((len(self._updates or []) * 2) + int(bool(self._removals))) def update_context(self, ctx): ctx_id = self.context_id if self.is_assignment: ctx[str(ctx_id)] = {} + elif self._removals is not None: + ctx[str(ctx_id)] = self._removals else: for key in self._updates or []: val = self.value.get(key) @@ -391,7 +412,7 @@ def 
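
IsNotNullClause renders an IS NOT NULL restriction while contributing nothing to the bind context, which is why update_context is a no-op and get_context_size returns 0. A minimal sketch (the rendered string is approximate):

.. code-block:: python

    from cassandra.cqlengine.statements import IsNotNullClause

    clause = IsNotNullClause('leading_col')
    str(clause)                # '"leading_col" IS NOT NULL'
    clause.get_context_size()  # 0
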
update_context(self, ctx): def is_assignment(self): if not self._analyzed: self._analyze() - return self.previous is None and not self._updates + return self.previous is None and not self._updates and not self._removals def __unicode__(self): qs = [] @@ -399,6 +420,9 @@ def __unicode__(self): ctx_id = self.context_id if self.is_assignment: qs += ['"{0}" = %({1})s'.format(self.field, ctx_id)] + elif self._removals is not None: + qs += ['"{0}" = "{0}" - %({1})s'.format(self.field, ctx_id)] + ctx_id += 1 else: for _ in self._updates or []: qs += ['"{0}"[%({1})s] = %({2})s'.format(self.field, ctx_id, ctx_id + 1)] @@ -481,10 +505,9 @@ def __unicode__(self): class BaseCQLStatement(UnicodeMixin): """ The base cql statement class """ - def __init__(self, table, consistency=None, timestamp=None, where=None, fetch_size=None, conditionals=None): + def __init__(self, table, timestamp=None, where=None, fetch_size=None, conditionals=None): super(BaseCQLStatement, self).__init__() self.table = table - self.consistency = consistency self.context_id = 0 self.context_counter = self.context_id self.timestamp = timestamp @@ -492,20 +515,27 @@ def __init__(self, table, consistency=None, timestamp=None, where=None, fetch_si self.where_clauses = [] for clause in where or []: - self.add_where_clause(clause) + self._add_where_clause(clause) self.conditionals = [] for conditional in conditionals or []: self.add_conditional_clause(conditional) - def add_where_clause(self, clause): - """ - adds a where clause to this statement - :param clause: the clause to add - :type clause: WhereClause - """ - if not isinstance(clause, WhereClause): - raise StatementException("only instances of WhereClause can be added to statements") + def _update_part_key_values(self, field_index_map, clauses, parts): + for clause in filter(lambda c: c.field in field_index_map, clauses): + parts[field_index_map[clause.field]] = clause.value + + def partition_key_values(self, field_index_map): + parts = [None] * len(field_index_map) + self._update_part_key_values(field_index_map, (w for w in self.where_clauses if w.operator.__class__ == EqualsOperator), parts) + return parts + + def add_where(self, column, operator, value, quote_field=True): + value = column.to_database(value) + clause = WhereClause(column.db_field_name, operator, value, quote_field) + self._add_where_clause(clause) + + def _add_where_clause(self, clause): clause.set_context_id(self.context_counter) self.context_counter += clause.get_context_size() self.where_clauses.append(clause) @@ -522,19 +552,17 @@ def get_context(self): def add_conditional_clause(self, clause): """ - Adds a iff clause to this statement + Adds an iff clause to this statement :param clause: The clause that will be added to the iff statement :type clause: ConditionalClause """ - if not isinstance(clause, ConditionalClause): - raise StatementException('only instances of AssignmentClause can be added to statements') clause.set_context_id(self.context_counter) self.context_counter += clause.get_context_size() self.conditionals.append(clause) def _get_conditionals(self): - return 'IF {0}'.format(' AND '.join([six.text_type(c) for c in self.conditionals])) + return 'IF {0}'.format(' AND '.join([str(c) for c in self.conditionals])) def get_context_size(self): return len(self.get_context()) @@ -549,13 +577,13 @@ def update_context_id(self, i): @property def timestamp_normalized(self): """ - we're expecting self.timestamp to be either a long, int, a datetime, or a timedelta + We're expecting self.timestamp to be 
either a long, int, a datetime, or a timedelta :return: """ if not self.timestamp: return None - if isinstance(self.timestamp, six.integer_types): + if isinstance(self.timestamp, int): return self.timestamp if isinstance(self.timestamp, timedelta): @@ -573,7 +601,7 @@ def __repr__(self): @property def _where(self): - return 'WHERE {0}'.format(' AND '.join([six.text_type(c) for c in self.where_clauses])) + return 'WHERE {0}'.format(' AND '.join([str(c) for c in self.where_clauses])) class SelectStatement(BaseCQLStatement): @@ -583,7 +611,6 @@ def __init__(self, table, fields=None, count=False, - consistency=None, where=None, order_by=None, limit=None, @@ -597,15 +624,14 @@ def __init__(self, """ super(SelectStatement, self).__init__( table, - consistency=consistency, where=where, fetch_size=fetch_size ) - self.fields = [fields] if isinstance(fields, six.string_types) else (fields or []) + self.fields = [fields] if isinstance(fields, str) else (fields or []) self.distinct_fields = distinct_fields self.count = count - self.order_by = [order_by] if isinstance(order_by, six.string_types) else order_by + self.order_by = [order_by] if isinstance(order_by, str) else order_by self.limit = limit self.allow_filtering = allow_filtering @@ -626,7 +652,7 @@ def __unicode__(self): qs += [self._where] if self.order_by and not self.count: - qs += ['ORDER BY {0}'.format(', '.join(six.text_type(o) for o in self.order_by))] + qs += ['ORDER BY {0}'.format(', '.join(str(o) for o in self.order_by))] if self.limit: qs += ['LIMIT {0}'.format(self.limit)] @@ -643,14 +669,12 @@ class AssignmentStatement(BaseCQLStatement): def __init__(self, table, assignments=None, - consistency=None, where=None, ttl=None, timestamp=None, conditionals=None): super(AssignmentStatement, self).__init__( table, - consistency=consistency, where=where, conditionals=conditionals ) @@ -660,7 +684,7 @@ def __init__(self, # add assignments self.assignments = [] for assignment in assignments or []: - self.add_assignment_clause(assignment) + self._add_assignment_clause(assignment) def update_context_id(self, i): super(AssignmentStatement, self).update_context_id(i) @@ -668,14 +692,17 @@ def update_context_id(self, i): assignment.set_context_id(self.context_counter) self.context_counter += assignment.get_context_size() - def add_assignment_clause(self, clause): - """ - adds an assignment clause to this statement - :param clause: the clause to add - :type clause: AssignmentClause - """ - if not isinstance(clause, AssignmentClause): - raise StatementException("only instances of AssignmentClause can be added to statements") + def partition_key_values(self, field_index_map): + parts = super(AssignmentStatement, self).partition_key_values(field_index_map) + self._update_part_key_values(field_index_map, self.assignments, parts) + return parts + + def add_assignment(self, column, value): + value = column.to_database(value) + clause = AssignmentClause(column.db_field_name, value) + self._add_assignment_clause(clause) + + def _add_assignment_clause(self, clause): clause.set_context_id(self.context_counter) self.context_counter += clause.get_context_size() self.assignments.append(clause) @@ -697,23 +724,18 @@ class InsertStatement(AssignmentStatement): def __init__(self, table, assignments=None, - consistency=None, where=None, ttl=None, timestamp=None, if_not_exists=False): super(InsertStatement, self).__init__(table, assignments=assignments, - consistency=consistency, where=where, ttl=ttl, timestamp=timestamp) self.if_not_exists = if_not_exists - def 
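
Tying the statement pieces together: a SelectStatement renders placeholder-style CQL, exposes its bind values through get_context(), and partition_key_values() lets the mapper collect equality-bound partition key values for routing-key computation. A sketch (rendered strings are approximate):

.. code-block:: python

    from cassandra.cqlengine.statements import SelectStatement, WhereClause
    from cassandra.cqlengine.operators import EqualsOperator

    ss = SelectStatement('example_ks.users', fields=['name'],
                         where=[WhereClause('user_id', EqualsOperator(), 42)])
    str(ss)           # 'SELECT "name" FROM example_ks.users WHERE "user_id" = %(0)s'
    ss.get_context()  # {'0': 42}

    # field_index_map maps partition-key db field names to their position
    ss.partition_key_values({'user_id': 0})   # [42]
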
add_where_clause(self, clause): - raise StatementException("Cannot add where clauses to insert statements") - def __unicode__(self): qs = ['INSERT INTO {0}'.format(self.table)] @@ -728,12 +750,15 @@ def __unicode__(self): if self.if_not_exists: qs += ["IF NOT EXISTS"] + using_options = [] if self.ttl: - qs += ["USING TTL {0}".format(self.ttl)] + using_options += ["TTL {}".format(self.ttl)] if self.timestamp: - qs += ["USING TIMESTAMP {0}".format(self.timestamp_normalized)] + using_options += ["TIMESTAMP {}".format(self.timestamp_normalized)] + if using_options: + qs += ["USING {}".format(" AND ".join(using_options))] return ' '.join(qs) @@ -743,7 +768,6 @@ class UpdateStatement(AssignmentStatement): def __init__(self, table, assignments=None, - consistency=None, where=None, ttl=None, timestamp=None, @@ -751,7 +775,6 @@ def __init__(self, if_exists=False): super(UpdateStatement, self). __init__(table, assignments=assignments, - consistency=consistency, where=where, ttl=ttl, timestamp=timestamp, @@ -774,7 +797,7 @@ def __unicode__(self): qs += ["USING {0}".format(" AND ".join(using_options))] qs += ['SET'] - qs += [', '.join([six.text_type(c) for c in self.assignments])] + qs += [', '.join([str(c) for c in self.assignments])] if self.where_clauses: qs += [self._where] @@ -807,26 +830,25 @@ def add_update(self, column, value, operation=None, previous=None): previous = column.to_database(previous) clause = container_update_type(column.db_field_name, value, operation, previous) elif col_type == columns.Counter: - clause = CounterUpdateClause(column.db_field_name, value) + clause = CounterUpdateClause(column.db_field_name, value, previous) else: clause = AssignmentClause(column.db_field_name, value) if clause.get_context_size(): # this is to exclude map removals from updates. Can go away if we drop support for C* < 1.2.4 and remove two-phase updates - self.add_assignment_clause(clause) + self._add_assignment_clause(clause) class DeleteStatement(BaseCQLStatement): """ a cql delete statement """ - def __init__(self, table, fields=None, consistency=None, where=None, timestamp=None, conditionals=None, if_exists=False): + def __init__(self, table, fields=None, where=None, timestamp=None, conditionals=None, if_exists=False): super(DeleteStatement, self).__init__( table, - consistency=consistency, where=where, timestamp=timestamp, conditionals=conditionals ) self.fields = [] - if isinstance(fields, six.string_types): + if isinstance(fields, str): fields = [fields] for field in fields or []: self.add_field(field) @@ -851,7 +873,7 @@ def get_context(self): return ctx def add_field(self, field): - if isinstance(field, six.string_types): + if isinstance(field, str): field = FieldDeleteClause(field) if not isinstance(field, BaseClause): raise StatementException("only instances of AssignmentClause can be added to statements") diff --git a/cassandra/cqlengine/usertype.py b/cassandra/cqlengine/usertype.py index 7d753e898d..e96534f9c6 100644 --- a/cassandra/cqlengine/usertype.py +++ b/cassandra/cqlengine/usertype.py @@ -1,10 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
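
With the change above, an INSERT emits TTL and TIMESTAMP as a single USING ... AND ... clause rather than two separate USING clauses. A sketch (the exact rendering is approximate):

.. code-block:: python

    from cassandra.cqlengine.statements import AssignmentClause, InsertStatement

    ist = InsertStatement('example_ks.users',
                          assignments=[AssignmentClause('user_id', 42)],
                          ttl=3600, timestamp=1234567890)
    str(ist)
    # 'INSERT INTO example_ks.users ("user_id") VALUES (%(0)s)
    #  USING TTL 3600 AND TIMESTAMP 1234567890'
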
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import re -import six from cassandra.util import OrderedDict from cassandra.cqlengine import CQLEngineException from cassandra.cqlengine import columns -from cassandra.cqlengine import connection +from cassandra.cqlengine import connection as conn from cassandra.cqlengine import models @@ -31,7 +46,8 @@ def __init__(self, **values): values = dict((self._db_map.get(k, k), v) for k, v in values.items()) for name, field in self._fields.items(): - value = values.get(name, None) + field_default = field.get_default() if field.has_default else None + value = values.get(name, field_default) if value is not None or isinstance(field, columns.BaseContainerColumn): value = field.to_python(value) value_mngr = field.value_manager(self, field, value) @@ -57,7 +73,7 @@ def __ne__(self, other): return not self.__eq__(other) def __str__(self): - return "{{{0}}}".format(', '.join("'{0}': {1}".format(k, getattr(self, k)) for k, v in six.iteritems(self._values))) + return "{{{0}}}".format(', '.join("'{0}': {1}".format(k, getattr(self, k)) for k, v in self._values.items())) def has_changed_fields(self): return any(v.changed for v in self._values.values()) @@ -78,14 +94,14 @@ def __getattr__(self, attr): raise AttributeError(attr) def __getitem__(self, key): - if not isinstance(key, six.string_types): + if not isinstance(key, str): raise TypeError if key not in self._fields.keys(): raise KeyError return getattr(self, key) def __setitem__(self, key, val): - if not isinstance(key, six.string_types): + if not isinstance(key, str): raise TypeError if key not in self._fields.keys(): raise KeyError @@ -111,8 +127,8 @@ def items(self): return [(k, self[k]) for k in self] @classmethod - def register_for_keyspace(cls, keyspace): - connection.register_udt(keyspace, cls.type_name(), cls) + def register_for_keyspace(cls, keyspace, connection=None): + conn.register_udt(keyspace, cls.type_name(), cls, connection=connection) @classmethod def type_name(cls): @@ -183,8 +199,7 @@ def _transform_column(field_name, field_obj): return klass -@six.add_metaclass(UserTypeMetaClass) -class UserType(BaseUserType): +class UserType(BaseUserType, metaclass=UserTypeMetaClass): """ This class is used to model User Defined Types. To define a type, declare a class inheriting from this, and assign field types as class attributes: diff --git a/cassandra/cqltypes.py b/cassandra/cqltypes.py index 26d7ebc0ff..5e063a0141 100644 --- a/cassandra/cqltypes.py +++ b/cassandra/cqltypes.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
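
register_for_keyspace() now forwards an optional connection, so a UDT can be registered against a specific session. A sketch, assuming the keyspace, the UDT schema, and the 'cluster2' connection already exist:

.. code-block:: python

    from cassandra.cqlengine import columns
    from cassandra.cqlengine.usertype import UserType

    class Address(UserType):
        street = columns.Text()
        zip_code = columns.Integer()

    # map the UDT for a keyspace on a specific registered connection
    Address.register_for_keyspace('example_ks', connection='cluster2')
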
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -28,29 +30,32 @@ # .from_cql_literal() and .as_cql_literal() classmethods (or whatever). from __future__ import absolute_import # to enable import io from stdlib +import ast from binascii import unhexlify import calendar from collections import namedtuple from decimal import Decimal import io +from itertools import chain import logging import re import socket import time -import six -from six.moves import range +import struct import sys from uuid import UUID -import warnings - from cassandra.marshal import (int8_pack, int8_unpack, int16_pack, int16_unpack, uint16_pack, uint16_unpack, uint32_pack, uint32_unpack, int32_pack, int32_unpack, int64_pack, int64_unpack, float_pack, float_unpack, double_pack, double_unpack, - varint_pack, varint_unpack) + varint_pack, varint_unpack, point_be, point_le, + vints_pack, vints_unpack, uvint_unpack, uvint_pack) from cassandra import util +_little_endian_flag = 1 # we always serialize LE +import ipaddress + apache_cassandra_type_prefix = 'org.apache.cassandra.db.marshal.' cassandra_empty_type = 'org.apache.cassandra.db.marshal.EmptyType' @@ -58,16 +63,12 @@ log = logging.getLogger(__name__) -if six.PY3: - _number_types = frozenset((int, float)) - long = int +_number_types = frozenset((int, float)) - def _name_from_hex_string(encoded_name): - bin_str = unhexlify(encoded_name) - return bin_str.decode('ascii') -else: - _number_types = frozenset((int, long, float)) - _name_from_hex_string = unhexlify + +def _name_from_hex_string(encoded_name): + bin_str = unhexlify(encoded_name) + return bin_str.decode('ascii') def trim_if_startswith(s, prefix): @@ -77,6 +78,7 @@ def trim_if_startswith(s, prefix): _casstypes = {} +_cqltypes = {} cql_type_scanner = re.Scanner(( @@ -106,6 +108,8 @@ def __new__(metacls, name, bases, dct): cls = type.__new__(metacls, name, bases, dct) if not name.startswith('_'): _casstypes[name] = cls + if not cls.typename.startswith(apache_cassandra_type_prefix): + _cqltypes[cls.typename] = cls return cls @@ -116,6 +120,73 @@ def __new__(metacls, name, bases, dct): )) +def cqltype_to_python(cql_string): + """ + Given a cql type string, creates a list that can be manipulated in python + Example: + int -> ['int'] + frozen> -> ['frozen', ['tuple', ['text', 'int']]] + """ + scanner = re.Scanner(( + (r'[a-zA-Z0-9_]+', lambda s, t: "'{}'".format(t)), + (r'<', lambda s, t: ', ['), + (r'>', lambda s, t: ']'), + (r'[, ]', lambda s, t: t), + (r'".*?"', lambda s, t: "'{}'".format(t)), + )) + + scanned_tokens = scanner.scan(cql_string)[0] + hierarchy = ast.literal_eval(''.join(scanned_tokens)) + return [hierarchy] if isinstance(hierarchy, str) else list(hierarchy) + + +def python_to_cqltype(types): + """ + Opposite of the `cql_to_python` function. 
Given a python list, creates a cql type string from the representation + Example: + ['int'] -> int + ['frozen', ['tuple', ['text', 'int']]] -> frozen> + """ + scanner = re.Scanner(( + (r"'[a-zA-Z0-9_]+'", lambda s, t: t[1:-1]), + (r'^\[', lambda s, t: None), + (r'\]$', lambda s, t: None), + (r',\s*\[', lambda s, t: '<'), + (r'\]', lambda s, t: '>'), + (r'[, ]', lambda s, t: t), + (r'\'".*?"\'', lambda s, t: t[1:-1]), + )) + + scanned_tokens = scanner.scan(repr(types))[0] + cql = ''.join(scanned_tokens).replace('\\\\', '\\') + return cql + + +def _strip_frozen_from_python(types): + """ + Given a python list representing a cql type, removes 'frozen' + Example: + ['frozen', ['tuple', ['text', 'int']]] -> ['tuple', ['text', 'int']] + """ + while 'frozen' in types: + index = types.index('frozen') + types = types[:index] + types[index + 1] + types[index + 2:] + new_types = [_strip_frozen_from_python(item) if isinstance(item, list) else item for item in types] + return new_types + + +def strip_frozen(cql): + """ + Given a cql type string, and removes frozen + Example: + frozen> -> tuple + """ + types = cqltype_to_python(cql) + types_without_frozen = _strip_frozen_from_python(types) + cql = python_to_cqltype(types_without_frozen) + return cql + + def lookup_casstype_simple(casstype): """ Given a Cassandra type name (either fully distinguished or not), hand @@ -157,13 +228,15 @@ def parse_casstype_args(typestring): else: names.append(None) - ctype = lookup_casstype_simple(tok) + try: + ctype = int(tok) + except ValueError: + ctype = lookup_casstype_simple(tok) types.append(ctype) # return the first (outer) type, which will have all parameters applied return args[0][0][0] - def lookup_casstype(casstype): """ Given a Cassandra type as a string (possibly including parameters), hand @@ -173,7 +246,7 @@ def lookup_casstype(casstype): Example: >>> lookup_casstype('org.apache.cassandra.db.marshal.MapType(org.apache.cassandra.db.marshal.UTF8Type,org.apache.cassandra.db.marshal.Int32Type)') - + """ if isinstance(casstype, (CassandraType, CassandraTypeType)): @@ -198,8 +271,7 @@ def __str__(self): EMPTY = EmptyValue() -@six.add_metaclass(CassandraTypeType) -class _CassandraType(object): +class _CassandraType(object, metaclass=CassandraTypeType): subtypes = () num_subtypes = 0 empty_binary_ok = False @@ -218,7 +290,7 @@ class _CassandraType(object): """ def __repr__(self): - return '<%s( %r )>' % (self.cql_parameterized_type(), self.val) + return '<%s>' % (self.cql_parameterized_type()) @classmethod def from_binary(cls, byts, protocol_version): @@ -293,7 +365,7 @@ def apply_parameters(cls, subtypes, names=None): using them as parameters. This is how composite types are constructed. >>> MapType.apply_parameters([DateType, BooleanType]) - + `subtypes` will be a sequence of CassandraTypes. If provided, `names` will be an equally long sequence of column names or Nones. 
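
These helpers convert between CQL type strings and a nested-list form that is easier to manipulate, and strip_frozen combines them to drop frozen wrappers at any depth. Roughly:

.. code-block:: python

    from cassandra.cqltypes import cqltype_to_python, python_to_cqltype, strip_frozen

    cqltype_to_python('frozen<tuple<text, int>>')
    # ['frozen', ['tuple', ['text', 'int']]]

    python_to_cqltype(['frozen', ['tuple', ['text', 'int']]])
    # 'frozen<tuple<text, int>>'

    strip_frozen('frozen<map<text, frozen<list<int>>>>')
    # 'map<text, list<int>>'
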
@@ -302,8 +374,6 @@ def apply_parameters(cls, subtypes, names=None): raise ValueError("%s types require %d subtypes (%d given)" % (cls.typename, cls.num_subtypes, len(subtypes))) newname = cls.cass_parameterized_type_with(subtypes) - if six.PY2 and isinstance(newname, unicode): - newname = newname.encode('utf-8') return type(newname, (cls,), {'subtypes': subtypes, 'cassname': cls.cassname, 'fieldnames': names}) @classmethod @@ -324,6 +394,9 @@ def cass_parameterized_type(cls, full=False): """ return cls.cass_parameterized_type_with(cls.subtypes, full=full) + @classmethod + def serial_size(cls): + return None # it's initially named with a _ to avoid registering it as a real type, but # client programs may want to use the name still for isinstance(), etc @@ -334,16 +407,10 @@ class _UnrecognizedType(_CassandraType): num_subtypes = 'UNKNOWN' -if six.PY3: - def mkUnrecognizedType(casstypename): - return CassandraTypeType(casstypename, - (_UnrecognizedType,), - {'typename': "'%s'" % casstypename}) -else: - def mkUnrecognizedType(casstypename): # noqa - return CassandraTypeType(casstypename.encode('utf8'), - (_UnrecognizedType,), - {'typename': "'%s'" % casstypename}) +def mkUnrecognizedType(casstypename): + return CassandraTypeType(casstypename, + (_UnrecognizedType,), + {'typename': "'%s'" % casstypename}) class BytesType(_CassandraType): @@ -352,7 +419,7 @@ class BytesType(_CassandraType): @staticmethod def serialize(val, protocol_version): - return six.binary_type(val) + return bytes(val) class DecimalType(_CassandraType): @@ -395,6 +462,9 @@ def serialize(uuid, protocol_version): except AttributeError: raise TypeError("Got a non-UUID object for a UUID value") + @classmethod + def serial_size(cls): + return 16 class BooleanType(_CassandraType): typename = 'boolean' @@ -407,6 +477,10 @@ def deserialize(byts, protocol_version): def serialize(truth, protocol_version): return int8_pack(truth) + @classmethod + def serial_size(cls): + return 1 + class ByteType(_CassandraType): typename = 'tinyint' @@ -419,25 +493,20 @@ def serialize(byts, protocol_version): return int8_pack(byts) -if six.PY2: - class AsciiType(_CassandraType): - typename = 'ascii' - empty_binary_ok = True -else: - class AsciiType(_CassandraType): - typename = 'ascii' - empty_binary_ok = True +class AsciiType(_CassandraType): + typename = 'ascii' + empty_binary_ok = True - @staticmethod - def deserialize(byts, protocol_version): - return byts.decode('ascii') + @staticmethod + def deserialize(byts, protocol_version): + return byts.decode('ascii') - @staticmethod - def serialize(var, protocol_version): - try: - return var.encode('ascii') - except UnicodeDecodeError: - return var + @staticmethod + def serialize(var, protocol_version): + try: + return var.encode('ascii') + except UnicodeDecodeError: + return var class FloatType(_CassandraType): @@ -451,6 +520,9 @@ def deserialize(byts, protocol_version): def serialize(byts, protocol_version): return float_pack(byts) + @classmethod + def serial_size(cls): + return 4 class DoubleType(_CassandraType): typename = 'double' @@ -463,6 +535,9 @@ def deserialize(byts, protocol_version): def serialize(byts, protocol_version): return double_pack(byts) + @classmethod + def serial_size(cls): + return 8 class LongType(_CassandraType): typename = 'bigint' @@ -475,6 +550,9 @@ def deserialize(byts, protocol_version): def serialize(byts, protocol_version): return int64_pack(byts) + @classmethod + def serial_size(cls): + return 8 class Int32Type(_CassandraType): typename = 'int' @@ -487,6 +565,9 @@ def 
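
serial_size() advertises a fixed serialized width where one exists and defaults to None for variable-length types; the vector codec later in this module uses it to choose between dense packing and length-prefixed elements. Roughly:

.. code-block:: python

    from cassandra.cqltypes import FloatType, UTF8Type, UUIDType

    FloatType.serial_size()   # 4
    UUIDType.serial_size()    # 16
    UTF8Type.serial_size()    # None -- variable length
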
deserialize(byts, protocol_version): def serialize(byts, protocol_version): return int32_pack(byts) + @classmethod + def serial_size(cls): + return 4 class IntegerType(_CassandraType): typename = 'varint' @@ -514,12 +595,17 @@ def deserialize(byts, protocol_version): @staticmethod def serialize(addr, protocol_version): - if ':' in addr: - return util.inet_pton(socket.AF_INET6, addr) - else: - # util.inet_pton could also handle, but this is faster - # since we've already determined the AF - return socket.inet_aton(addr) + try: + if ':' in addr: + return util.inet_pton(socket.AF_INET6, addr) + else: + # util.inet_pton could also handle, but this is faster + # since we've already determined the AF + return socket.inet_aton(addr) + except: + if isinstance(addr, (ipaddress.IPv4Address, ipaddress.IPv6Address)): + return addr.packed + raise ValueError("can't interpret %r as an inet address" % (addr,)) class CounterColumnType(LongType): @@ -576,8 +662,11 @@ def serialize(v, protocol_version): raise TypeError('DateType arguments must be a datetime, date, or timestamp') timestamp = v - return int64_pack(long(timestamp)) + return int64_pack(int(timestamp)) + @classmethod + def serial_size(cls): + return 8 class TimestampType(DateType): pass @@ -600,6 +689,9 @@ def serialize(timeuuid, protocol_version): except AttributeError: raise TypeError("Got a non-UUID object for a UUID value") + @classmethod + def serial_size(cls): + return 16 class SimpleDateType(_CassandraType): typename = 'date' @@ -620,6 +712,11 @@ def serialize(val, protocol_version): try: days = val.days_from_epoch except AttributeError: + if isinstance(val, int): + # the DB wants offset int values, but util.Date init takes days from epoch + # here we assume int values are offset, as they would appear in CQL + # short circuit to avoid subtracting just to add offset + return uint32_pack(val) days = util.Date(val).days_from_epoch return uint32_pack(days + SimpleDateType.EPOCH_OFFSET_DAYS) @@ -635,9 +732,14 @@ def deserialize(byts, protocol_version): def serialize(byts, protocol_version): return int16_pack(byts) - class TimeType(_CassandraType): typename = 'time' + # Time should be a fixed size 8 byte type but Cassandra 5.0 code marks it as + # variable size... and we have to match what the server expects since the server + # uses that specification to encode data of that type. 
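
InetAddressType.serialize now falls back to ipaddress objects when the value is not a plain string, so both forms below should yield the packed network-order bytes (a rough sketch):

.. code-block:: python

    import ipaddress
    from cassandra.cqltypes import InetAddressType

    InetAddressType.serialize('127.0.0.1', 4)                  # b'\x7f\x00\x00\x01'
    InetAddressType.serialize(ipaddress.ip_address('::1'), 4)  # 16-byte packed IPv6
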
+ #@classmethod + #def serial_size(cls): + # return 8 @staticmethod def deserialize(byts, protocol_version): @@ -652,6 +754,23 @@ def serialize(val, protocol_version): return int64_pack(nano) +class DurationType(_CassandraType): + typename = 'duration' + + @staticmethod + def deserialize(byts, protocol_version): + months, days, nanoseconds = vints_unpack(byts) + return util.Duration(months, days, nanoseconds) + + @staticmethod + def serialize(duration, protocol_version): + try: + m, d, n = duration.months, duration.days, duration.nanoseconds + except AttributeError: + raise TypeError('DurationType arguments must be a Duration.') + return vints_pack([m, d, n]) + + class UTF8Type(_CassandraType): typename = 'text' empty_binary_ok = True @@ -674,6 +793,8 @@ class VarcharType(UTF8Type): class _ParameterizedType(_CassandraType): + num_subtypes = 'UNKNOWN' + @classmethod def deserialize(cls, byts, protocol_version): if not cls.subtypes: @@ -706,14 +827,17 @@ def deserialize_safe(cls, byts, protocol_version): for _ in range(numelements): itemlen = unpack(byts[p:p + length]) p += length - item = byts[p:p + itemlen] - p += itemlen - result.append(subtype.from_binary(item, inner_proto)) + if itemlen < 0: + result.append(None) + else: + item = byts[p:p + itemlen] + p += itemlen + result.append(subtype.from_binary(item, inner_proto)) return cls.adapter(result) @classmethod def serialize_safe(cls, items, protocol_version): - if isinstance(items, six.string_types): + if isinstance(items, str): raise TypeError("Received a string for a type that expects a sequence") subtype, = cls.subtypes @@ -760,14 +884,23 @@ def deserialize_safe(cls, byts, protocol_version): for _ in range(numelements): key_len = unpack(byts[p:p + length]) p += length - keybytes = byts[p:p + key_len] - p += key_len + if key_len < 0: + keybytes = None + key = None + else: + keybytes = byts[p:p + key_len] + p += key_len + key = key_type.from_binary(keybytes, inner_proto) + val_len = unpack(byts[p:p + length]) p += length - valbytes = byts[p:p + val_len] - p += val_len - key = key_type.from_binary(keybytes, inner_proto) - val = value_type.from_binary(valbytes, inner_proto) + if val_len < 0: + val = None + else: + valbytes = byts[p:p + val_len] + p += val_len + val = value_type.from_binary(valbytes, inner_proto) + themap._insert_unchecked(key, keybytes, val) return themap @@ -778,7 +911,7 @@ def serialize_safe(cls, themap, protocol_version): buf = io.BytesIO() buf.write(pack(len(themap))) try: - items = six.iteritems(themap) + items = themap.items() except AttributeError: raise TypeError("Got a non-map object for a map value") inner_proto = max(3, protocol_version) @@ -794,7 +927,6 @@ def serialize_safe(cls, themap, protocol_version): class TupleType(_ParameterizedType): typename = 'tuple' - num_subtypes = 'UNKNOWN' @classmethod def deserialize_safe(cls, byts, protocol_version): @@ -845,7 +977,7 @@ def cql_parameterized_type(cls): class UserType(TupleType): - typename = "'org.apache.cassandra.db.marshal.UserType'" + typename = "org.apache.cassandra.db.marshal.UserType" _cache = {} _module = sys.modules[__name__] @@ -854,9 +986,6 @@ class UserType(TupleType): def make_udt_class(cls, keyspace, udt_name, field_names, field_types): assert len(field_names) == len(field_types) - if six.PY2 and isinstance(udt_name, unicode): - udt_name = udt_name.encode('utf-8') - instance = cls._cache.get((keyspace, udt_name)) if not instance or instance.fieldnames != field_names or instance.subtypes != field_types: instance = type(udt_name, (cls,), {'subtypes': 
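
The duration codec packs months, days and nanoseconds as three signed vints and accepts any object exposing those attributes, such as cassandra.util.Duration. A round-trip sketch:

.. code-block:: python

    from cassandra import util
    from cassandra.cqltypes import DurationType

    d = util.Duration(months=1, days=2, nanoseconds=3000000000)
    raw = DurationType.serialize(d, 5)     # three signed vints
    DurationType.deserialize(raw, 5) == d  # True
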
field_types, @@ -871,8 +1000,6 @@ def make_udt_class(cls, keyspace, udt_name, field_names, field_types): @classmethod def evict_udt_class(cls, keyspace, udt_name): - if six.PY2 and isinstance(udt_name, unicode): - udt_name = udt_name.encode('utf-8') try: del cls._cache[(keyspace, udt_name)] except KeyError: @@ -908,7 +1035,9 @@ def serialize_safe(cls, val, protocol_version): try: item = val[i] except TypeError: - item = getattr(val, fieldname) + item = getattr(val, fieldname, None) + if item is None and not hasattr(val, fieldname): + log.warning(f"field {fieldname} is part of the UDT {cls.typename} but is not present in the value {val}") if item is not None: packed_item = subtype.to_binary(item, proto_version) @@ -938,18 +1067,19 @@ def _make_udt_tuple_type(cls, name, field_names): except ValueError: try: t = namedtuple(name, util._positional_rename_invalid_identifiers(field_names)) - log.warn("could not create a namedtuple for '%s' because one or more field names are not valid Python identifiers (%s); " \ - "returning positionally-named fields" % (name, field_names)) + log.warning("could not create a namedtuple for '%s' because one or more " + "field names are not valid Python identifiers (%s); " + "returning positionally-named fields" % (name, field_names)) except ValueError: t = None - log.warn("could not create a namedtuple for '%s' because the name is not a valid Python identifier; " \ - "will return tuples in its place" % (name,)) + log.warning("could not create a namedtuple for '%s' because the name is " + "not a valid Python identifier; will return tuples in " + "its place" % (name,)) return t class CompositeType(_ParameterizedType): - typename = "'org.apache.cassandra.db.marshal.CompositeType'" - num_subtypes = 'UNKNOWN' + typename = "org.apache.cassandra.db.marshal.CompositeType" @classmethod def cql_parameterized_type(cls): @@ -977,8 +1107,13 @@ def deserialize_safe(cls, byts, protocol_version): return tuple(result) -class DynamicCompositeType(CompositeType): - typename = "'org.apache.cassandra.db.marshal.DynamicCompositeType'" +class DynamicCompositeType(_ParameterizedType): + typename = "org.apache.cassandra.db.marshal.DynamicCompositeType" + + @classmethod + def cql_parameterized_type(cls): + sublist = ', '.join('%s=>%s' % (alias, typ.cass_parameterized_type(full=True)) for alias, typ in zip(cls.fieldnames, cls.subtypes)) + return "'%s(%s)'" % (cls.typename, sublist) class ColumnToCollectionType(_ParameterizedType): @@ -987,18 +1122,17 @@ class ColumnToCollectionType(_ParameterizedType): Cassandra includes this. We don't actually need or want the extra information. 
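
make_udt_class builds (and caches) a parameterized marshal type for a UDT from its field names and field types; the generated class carries the metadata used during (de)serialization. A rough sketch:

.. code-block:: python

    from cassandra.cqltypes import Int32Type, UserType, UTF8Type

    addr_type = UserType.make_udt_class(
        'example_ks', 'address', ['street', 'zip_code'], [UTF8Type, Int32Type])

    addr_type.fieldnames   # ['street', 'zip_code']
    addr_type.subtypes     # [UTF8Type, Int32Type]
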
""" - typename = "'org.apache.cassandra.db.marshal.ColumnToCollectionType'" - num_subtypes = 'UNKNOWN' + typename = "org.apache.cassandra.db.marshal.ColumnToCollectionType" class ReversedType(_ParameterizedType): - typename = "'org.apache.cassandra.db.marshal.ReversedType'" + typename = "org.apache.cassandra.db.marshal.ReversedType" num_subtypes = 1 @classmethod def deserialize_safe(cls, byts, protocol_version): subtype, = cls.subtypes - return subtype.from_binary(byts) + return subtype.from_binary(byts, protocol_version) @classmethod def serialize_safe(cls, val, protocol_version): @@ -1013,7 +1147,7 @@ class FrozenType(_ParameterizedType): @classmethod def deserialize_safe(cls, byts, protocol_version): subtype, = cls.subtypes - return subtype.from_binary(byts) + return subtype.from_binary(byts, protocol_version) @classmethod def serialize_safe(cls, val, protocol_version): @@ -1022,7 +1156,7 @@ def serialize_safe(cls, val, protocol_version): def is_counter_type(t): - if isinstance(t, six.string_types): + if isinstance(t, str): t = lookup_casstype(t) return issubclass(t, CounterColumnType) @@ -1039,3 +1173,327 @@ def cql_typename(casstypename): 'list' """ return lookup_casstype(casstypename).cql_parameterized_type() + + +class WKBGeometryType(object): + POINT = 1 + LINESTRING = 2 + POLYGON = 3 + + +class PointType(CassandraType): + typename = 'PointType' + + _type = struct.pack('[[]] + type_ = int8_unpack(byts[0:1]) + + if type_ in (BoundKind.to_int(BoundKind.BOTH_OPEN_RANGE), + BoundKind.to_int(BoundKind.SINGLE_DATE_OPEN)): + time0 = precision0 = None + else: + time0 = int64_unpack(byts[1:9]) + precision0 = int8_unpack(byts[9:10]) + + if type_ == BoundKind.to_int(BoundKind.CLOSED_RANGE): + time1 = int64_unpack(byts[10:18]) + precision1 = int8_unpack(byts[18:19]) + else: + time1 = precision1 = None + + if time0 is not None: + date_range_bound0 = util.DateRangeBound( + time0, + cls._decode_precision(precision0) + ) + if time1 is not None: + date_range_bound1 = util.DateRangeBound( + time1, + cls._decode_precision(precision1) + ) + + if type_ == BoundKind.to_int(BoundKind.SINGLE_DATE): + return util.DateRange(value=date_range_bound0) + if type_ == BoundKind.to_int(BoundKind.CLOSED_RANGE): + return util.DateRange(lower_bound=date_range_bound0, + upper_bound=date_range_bound1) + if type_ == BoundKind.to_int(BoundKind.OPEN_RANGE_HIGH): + return util.DateRange(lower_bound=date_range_bound0, + upper_bound=util.OPEN_BOUND) + if type_ == BoundKind.to_int(BoundKind.OPEN_RANGE_LOW): + return util.DateRange(lower_bound=util.OPEN_BOUND, + upper_bound=date_range_bound0) + if type_ == BoundKind.to_int(BoundKind.BOTH_OPEN_RANGE): + return util.DateRange(lower_bound=util.OPEN_BOUND, + upper_bound=util.OPEN_BOUND) + if type_ == BoundKind.to_int(BoundKind.SINGLE_DATE_OPEN): + return util.DateRange(value=util.OPEN_BOUND) + raise ValueError('Could not deserialize %r' % (byts,)) + + @classmethod + def serialize(cls, v, protocol_version): + buf = io.BytesIO() + bound_kind, bounds = None, () + + try: + value = v.value + except AttributeError: + raise ValueError( + '%s.serialize expects an object with a value attribute; got' + '%r' % (cls.__name__, v) + ) + + if value is None: + try: + lower_bound, upper_bound = v.lower_bound, v.upper_bound + except AttributeError: + raise ValueError( + '%s.serialize expects an object with lower_bound and ' + 'upper_bound attributes; got %r' % (cls.__name__, v) + ) + if lower_bound == util.OPEN_BOUND and upper_bound == util.OPEN_BOUND: + bound_kind = BoundKind.BOTH_OPEN_RANGE + 
elif lower_bound == util.OPEN_BOUND: + bound_kind = BoundKind.OPEN_RANGE_LOW + bounds = (upper_bound,) + elif upper_bound == util.OPEN_BOUND: + bound_kind = BoundKind.OPEN_RANGE_HIGH + bounds = (lower_bound,) + else: + bound_kind = BoundKind.CLOSED_RANGE + bounds = lower_bound, upper_bound + else: # value is not None + if value == util.OPEN_BOUND: + bound_kind = BoundKind.SINGLE_DATE_OPEN + else: + bound_kind = BoundKind.SINGLE_DATE + bounds = (value,) + + if bound_kind is None: + raise ValueError( + 'Cannot serialize %r; could not find bound kind' % (v,) + ) + + buf.write(int8_pack(BoundKind.to_int(bound_kind))) + for bound in bounds: + buf.write(int64_pack(bound.milliseconds)) + buf.write(int8_pack(cls._encode_precision(bound.precision))) + + return buf.getvalue() + +class VectorType(_CassandraType): + typename = 'org.apache.cassandra.db.marshal.VectorType' + vector_size = 0 + subtype = None + + @classmethod + def serial_size(cls): + serialized_size = cls.subtype.serial_size() + return cls.vector_size * serialized_size if serialized_size is not None else None + + @classmethod + def apply_parameters(cls, params, names): + assert len(params) == 2 + subtype = lookup_casstype(params[0]) + vsize = params[1] + return type('%s(%s)' % (cls.cass_parameterized_type_with([]), vsize), (cls,), {'vector_size': vsize, 'subtype': subtype}) + + @classmethod + def deserialize(cls, byts, protocol_version): + serialized_size = cls.subtype.serial_size() + if serialized_size is not None: + expected_byte_size = serialized_size * cls.vector_size + if len(byts) != expected_byte_size: + raise ValueError( + "Expected vector of type {0} and dimension {1} to have serialized size {2}; observed serialized size of {3} instead"\ + .format(cls.subtype.typename, cls.vector_size, expected_byte_size, len(byts))) + indexes = (serialized_size * x for x in range(0, cls.vector_size)) + return [cls.subtype.deserialize(byts[idx:idx + serialized_size], protocol_version) for idx in indexes] + + idx = 0 + rv = [] + while (len(rv) < cls.vector_size): + try: + size, bytes_read = uvint_unpack(byts[idx:]) + idx += bytes_read + rv.append(cls.subtype.deserialize(byts[idx:idx + size], protocol_version)) + idx += size + except: + raise ValueError("Error reading additional data during vector deserialization after successfully adding {} elements"\ + .format(len(rv))) + + # If we have any additional data in the serialized vector treat that as an error as well + if idx < len(byts): + raise ValueError("Additional bytes remaining after vector deserialization completed") + return rv + + @classmethod + def serialize(cls, v, protocol_version): + v_length = len(v) + if cls.vector_size != v_length: + raise ValueError( + "Expected sequence of size {0} for vector of type {1} and dimension {0}, observed sequence of length {2}"\ + .format(cls.vector_size, cls.subtype.typename, v_length)) + + serialized_size = cls.subtype.serial_size() + buf = io.BytesIO() + for item in v: + item_bytes = cls.subtype.serialize(item, protocol_version) + if serialized_size is None: + buf.write(uvint_pack(len(item_bytes))) + buf.write(item_bytes) + return buf.getvalue() + + @classmethod + def cql_parameterized_type(cls): + return "%s<%s, %s>" % (cls.typename, cls.subtype.cql_parameterized_type(), cls.vector_size) diff --git a/cassandra/cython_marshal.pyx b/cassandra/cython_marshal.pyx index 61b6daccc1..4733a47935 100644 --- a/cassandra/cython_marshal.pyx +++ b/cassandra/cython_marshal.pyx @@ -1,12 +1,14 @@ # -- cython: profile=True # -# Copyright 2013-2016 DataStax, Inc. 
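To make the ``VectorType`` wire format above easier to follow: for a fixed-size subtype the elements are simply concatenated, while for variable-size subtypes each element is preceded by an unsigned vint length (the ``uvint_pack`` path above). Below is a self-contained sketch of the fixed-size case only, assuming big-endian IEEE-754 floats as used elsewhere in the native protocol; it is an illustration, not driver code:

.. code-block:: python

    import struct

    def pack_float_vector(values):
        # vector<float, n>: n 4-byte elements back to back, no length prefixes
        return b''.join(struct.pack('>f', v) for v in values)

    def unpack_float_vector(blob, dimension):
        expected = 4 * dimension
        if len(blob) != expected:
            raise ValueError('expected %d bytes, got %d' % (expected, len(blob)))
        return [struct.unpack_from('>f', blob, 4 * i)[0] for i in range(dimension)]

    print(unpack_float_vector(pack_float_vector([1.0, 2.5, -3.0]), 3))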
+# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -14,8 +16,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import six - from libc.stdint cimport (int8_t, int16_t, int32_t, int64_t, uint8_t, uint16_t, uint32_t, uint64_t) from libc.string cimport memcpy @@ -24,8 +24,6 @@ from cassandra.buffer cimport Buffer, buf_read, to_bytes cdef bint is_little_endian from cassandra.util import is_little_endian -cdef bint PY3 = six.PY3 - ctypedef fused num_t: int64_t int32_t @@ -57,10 +55,7 @@ cdef inline num_t unpack_num(Buffer *buf, num_t *dummy=NULL): # dummy pointer be cdef varint_unpack(Buffer *term): """Unpack a variable-sized integer""" - if PY3: - return varint_unpack_py3(to_bytes(term)) - else: - return varint_unpack_py2(to_bytes(term)) + return varint_unpack_py3(to_bytes(term)) # TODO: Optimize these two functions cdef varint_unpack_py3(bytes term): @@ -70,13 +65,6 @@ cdef varint_unpack_py3(bytes term): val -= 1 << shift return val -cdef varint_unpack_py2(bytes term): # noqa - val = int(term.encode('hex'), 16) - if (ord(term[0]) & 128) != 0: - shift = len(term) * 8 # * Note below - val = val - (1 << shift) - return val - # * Note * # '1 << (len(term) * 8)' Cython tries to do native # integer shifts, which overflows. We need this to diff --git a/cassandra/cython_utils.pyx b/cassandra/cython_utils.pyx index 3c6fae036b..1b6a136c69 100644 --- a/cassandra/cython_utils.pyx +++ b/cassandra/cython_utils.pyx @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/cassandra/datastax/__init__.py b/cassandra/datastax/__init__.py new file mode 100644 index 0000000000..635f0d9e60 --- /dev/null +++ b/cassandra/datastax/__init__.py @@ -0,0 +1,15 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cassandra/datastax/cloud/__init__.py b/cassandra/datastax/cloud/__init__.py new file mode 100644 index 0000000000..e175b2928b --- /dev/null +++ b/cassandra/datastax/cloud/__init__.py @@ -0,0 +1,195 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import logging +import json +import sys +import tempfile +import shutil +from urllib.request import urlopen + +_HAS_SSL = True +try: + from ssl import SSLContext, PROTOCOL_TLS, CERT_REQUIRED +except: + _HAS_SSL = False + +from zipfile import ZipFile + +# 2.7 vs 3.x +try: + from zipfile import BadZipFile +except: + from zipfile import BadZipfile as BadZipFile + +from cassandra import DriverException + +log = logging.getLogger(__name__) + +__all__ = ['get_cloud_config'] + +DATASTAX_CLOUD_PRODUCT_TYPE = "DATASTAX_APOLLO" + + +class CloudConfig(object): + + username = None + password = None + host = None + port = None + keyspace = None + local_dc = None + ssl_context = None + + sni_host = None + sni_port = None + host_ids = None + + @classmethod + def from_dict(cls, d): + c = cls() + + c.port = d.get('port', None) + try: + c.port = int(d['port']) + except: + pass + + c.username = d.get('username', None) + c.password = d.get('password', None) + c.host = d.get('host', None) + c.keyspace = d.get('keyspace', None) + c.local_dc = d.get('localDC', None) + + return c + + +def get_cloud_config(cloud_config, create_pyopenssl_context=False): + if not _HAS_SSL: + raise DriverException("A Python installation with SSL is required to connect to a cloud cluster.") + + if 'secure_connect_bundle' not in cloud_config: + raise ValueError("The cloud config doesn't have a secure_connect_bundle specified.") + + try: + config = read_cloud_config_from_zip(cloud_config, create_pyopenssl_context) + except BadZipFile: + raise ValueError("Unable to open the zip file for the cloud config. 
Check your secure connect bundle.") + + config = read_metadata_info(config, cloud_config) + if create_pyopenssl_context: + config.ssl_context = config.pyopenssl_context + return config + + +def read_cloud_config_from_zip(cloud_config, create_pyopenssl_context): + secure_bundle = cloud_config['secure_connect_bundle'] + use_default_tempdir = cloud_config.get('use_default_tempdir', None) + with ZipFile(secure_bundle) as zipfile: + base_dir = tempfile.gettempdir() if use_default_tempdir else os.path.dirname(secure_bundle) + tmp_dir = tempfile.mkdtemp(dir=base_dir) + try: + zipfile.extractall(path=tmp_dir) + return parse_cloud_config(os.path.join(tmp_dir, 'config.json'), cloud_config, create_pyopenssl_context) + finally: + shutil.rmtree(tmp_dir) + + +def parse_cloud_config(path, cloud_config, create_pyopenssl_context): + with open(path, 'r') as stream: + data = json.load(stream) + + config = CloudConfig.from_dict(data) + config_dir = os.path.dirname(path) + + if 'ssl_context' in cloud_config: + config.ssl_context = cloud_config['ssl_context'] + else: + # Load the ssl_context before we delete the temporary directory + ca_cert_location = os.path.join(config_dir, 'ca.crt') + cert_location = os.path.join(config_dir, 'cert') + key_location = os.path.join(config_dir, 'key') + # Regardless of if we create a pyopenssl context, we still need the builtin one + # to connect to the metadata service + config.ssl_context = _ssl_context_from_cert(ca_cert_location, cert_location, key_location) + if create_pyopenssl_context: + config.pyopenssl_context = _pyopenssl_context_from_cert(ca_cert_location, cert_location, key_location) + + return config + + +def read_metadata_info(config, cloud_config): + url = "https://{}:{}/metadata".format(config.host, config.port) + timeout = cloud_config['connect_timeout'] if 'connect_timeout' in cloud_config else 5 + try: + response = urlopen(url, context=config.ssl_context, timeout=timeout) + except Exception as e: + log.exception(e) + raise DriverException("Unable to connect to the metadata service at %s. " + "Check the cluster status in the cloud console. " % url) + + if response.code != 200: + raise DriverException(("Error while fetching the metadata at: %s. " + "The service returned error code %d." 
% (url, response.code))) + return parse_metadata_info(config, response.read().decode('utf-8')) + + +def parse_metadata_info(config, http_data): + try: + data = json.loads(http_data) + except: + msg = "Failed to load cluster metadata" + raise DriverException(msg) + + contact_info = data['contact_info'] + config.local_dc = contact_info['local_dc'] + + proxy_info = contact_info['sni_proxy_address'].split(':') + config.sni_host = proxy_info[0] + try: + config.sni_port = int(proxy_info[1]) + except: + config.sni_port = 9042 + + config.host_ids = [host_id for host_id in contact_info['contact_points']] + + return config + + +def _ssl_context_from_cert(ca_cert_location, cert_location, key_location): + ssl_context = SSLContext(PROTOCOL_TLS) + ssl_context.load_verify_locations(ca_cert_location) + ssl_context.verify_mode = CERT_REQUIRED + ssl_context.load_cert_chain(certfile=cert_location, keyfile=key_location) + + return ssl_context + + +def _pyopenssl_context_from_cert(ca_cert_location, cert_location, key_location): + try: + from OpenSSL import SSL + except ImportError as e: + raise ImportError( + "PyOpenSSL must be installed to connect to Astra with the Eventlet or Twisted event loops")\ + .with_traceback(e.__traceback__) + ssl_context = SSL.Context(SSL.TLSv1_METHOD) + ssl_context.set_verify(SSL.VERIFY_PEER, callback=lambda _1, _2, _3, _4, ok: ok) + ssl_context.use_certificate_file(cert_location) + ssl_context.use_privatekey_file(key_location) + ssl_context.load_verify_locations(ca_cert_location) + + return ssl_context \ No newline at end of file diff --git a/cassandra/datastax/graph/__init__.py b/cassandra/datastax/graph/__init__.py new file mode 100644 index 0000000000..8315843a36 --- /dev/null +++ b/cassandra/datastax/graph/__init__.py @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from cassandra.datastax.graph.types import Element, Vertex, VertexProperty, Edge, Path, T +from cassandra.datastax.graph.query import ( + GraphOptions, GraphProtocol, GraphStatement, SimpleGraphStatement, Result, + graph_object_row_factory, single_object_row_factory, + graph_result_row_factory, graph_graphson2_row_factory, + graph_graphson3_row_factory +) +from cassandra.datastax.graph.graphson import * diff --git a/cassandra/datastax/graph/fluent/__init__.py b/cassandra/datastax/graph/fluent/__init__.py new file mode 100644 index 0000000000..0dfd5230e5 --- /dev/null +++ b/cassandra/datastax/graph/fluent/__init__.py @@ -0,0 +1,305 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import copy + +from concurrent.futures import Future + +HAVE_GREMLIN = False +try: + import gremlin_python + HAVE_GREMLIN = True +except ImportError: + # gremlinpython is not installed. + pass + +if HAVE_GREMLIN: + from gremlin_python.structure.graph import Graph + from gremlin_python.driver.remote_connection import RemoteConnection, RemoteTraversal + from gremlin_python.process.traversal import Traverser, TraversalSideEffects + from gremlin_python.process.graph_traversal import GraphTraversal + + from cassandra.cluster import Session, GraphExecutionProfile, EXEC_PROFILE_GRAPH_DEFAULT + from cassandra.datastax.graph import GraphOptions, GraphProtocol + from cassandra.datastax.graph.query import _GraphSONContextRowFactory + + from cassandra.datastax.graph.fluent.serializers import ( + GremlinGraphSONReaderV2, + GremlinGraphSONReaderV3, + dse_graphson2_deserializers, + gremlin_graphson2_deserializers, + dse_graphson3_deserializers, + gremlin_graphson3_deserializers + ) + from cassandra.datastax.graph.fluent.query import _DefaultTraversalBatch, _query_from_traversal + + log = logging.getLogger(__name__) + + __all__ = ['BaseGraphRowFactory', 'graph_traversal_row_factory', + 'graph_traversal_dse_object_row_factory', 'DSESessionRemoteGraphConnection', 'DseGraph'] + + # Traversal result keys + _bulk_key = 'bulk' + _result_key = 'result' + + + class BaseGraphRowFactory(_GraphSONContextRowFactory): + """ + Base row factory for graph traversal. This class basically wraps a + graphson reader function to handle additional features of Gremlin/DSE + and is callable as a normal row factory. 
+ + Currently supported: + - bulk results + """ + + def __call__(self, column_names, rows): + for row in rows: + parsed_row = self.graphson_reader.readObject(row[0]) + yield parsed_row[_result_key] + bulk = parsed_row.get(_bulk_key, 1) + for _ in range(bulk - 1): + yield copy.deepcopy(parsed_row[_result_key]) + + + class _GremlinGraphSON2RowFactory(BaseGraphRowFactory): + """Row Factory that returns the decoded graphson2.""" + graphson_reader_class = GremlinGraphSONReaderV2 + graphson_reader_kwargs = {'deserializer_map': gremlin_graphson2_deserializers} + + + class _DseGraphSON2RowFactory(BaseGraphRowFactory): + """Row Factory that returns the decoded graphson2 as DSE types.""" + graphson_reader_class = GremlinGraphSONReaderV2 + graphson_reader_kwargs = {'deserializer_map': dse_graphson2_deserializers} + + gremlin_graphson2_traversal_row_factory = _GremlinGraphSON2RowFactory + # TODO remove in next major + graph_traversal_row_factory = gremlin_graphson2_traversal_row_factory + + dse_graphson2_traversal_row_factory = _DseGraphSON2RowFactory + # TODO remove in next major + graph_traversal_dse_object_row_factory = dse_graphson2_traversal_row_factory + + + class _GremlinGraphSON3RowFactory(BaseGraphRowFactory): + """Row Factory that returns the decoded graphson2.""" + graphson_reader_class = GremlinGraphSONReaderV3 + graphson_reader_kwargs = {'deserializer_map': gremlin_graphson3_deserializers} + + + class _DseGraphSON3RowFactory(BaseGraphRowFactory): + """Row Factory that returns the decoded graphson3 as DSE types.""" + graphson_reader_class = GremlinGraphSONReaderV3 + graphson_reader_kwargs = {'deserializer_map': dse_graphson3_deserializers} + + + gremlin_graphson3_traversal_row_factory = _GremlinGraphSON3RowFactory + dse_graphson3_traversal_row_factory = _DseGraphSON3RowFactory + + + class DSESessionRemoteGraphConnection(RemoteConnection): + """ + A Tinkerpop RemoteConnection to execute traversal queries on DSE. + + :param session: A DSE session + :param graph_name: (Optional) DSE Graph name. + :param execution_profile: (Optional) Execution profile for traversal queries. Default is set to `EXEC_PROFILE_GRAPH_DEFAULT`. 
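A hedged usage sketch for this connection class (the graph name and data are made up, and gremlinpython plus a reachable DSE Graph cluster are assumed); ``DseGraph.traversal_source`` defined later in this file is a convenience wrapper around the same wiring:

.. code-block:: python

    from cassandra.cluster import Cluster
    from cassandra.datastax.graph.fluent import DSESessionRemoteGraphConnection
    from gremlin_python.structure.graph import Graph

    session = Cluster().connect()
    conn = DSESessionRemoteGraphConnection(session, graph_name='demo_graph')

    # Standard TinkerPop wiring: traversals built from `g` execute on DSE.
    g = Graph().traversal().withRemote(conn)
    print(g.V().count().next())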
+ """ + + session = None + graph_name = None + execution_profile = None + + def __init__(self, session, graph_name=None, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT): + super(DSESessionRemoteGraphConnection, self).__init__(None, None) + + if not isinstance(session, Session): + raise ValueError('A DSE Session must be provided to execute graph traversal queries.') + + self.session = session + self.graph_name = graph_name + self.execution_profile = execution_profile + + @staticmethod + def _traversers_generator(traversers): + for t in traversers: + yield Traverser(t) + + def _prepare_query(self, bytecode): + ep = self.session.execution_profile_clone_update(self.execution_profile) + graph_options = ep.graph_options + graph_options.graph_name = self.graph_name or graph_options.graph_name + graph_options.graph_language = DseGraph.DSE_GRAPH_QUERY_LANGUAGE + # We resolve the execution profile options here , to know how what gremlin factory to set + self.session._resolve_execution_profile_options(ep) + + context = None + if graph_options.graph_protocol == GraphProtocol.GRAPHSON_2_0: + row_factory = gremlin_graphson2_traversal_row_factory + elif graph_options.graph_protocol == GraphProtocol.GRAPHSON_3_0: + row_factory = gremlin_graphson3_traversal_row_factory + context = { + 'cluster': self.session.cluster, + 'graph_name': graph_options.graph_name.decode('utf-8') + } + else: + raise ValueError('Unknown graph protocol: {}'.format(graph_options.graph_protocol)) + + ep.row_factory = row_factory + query = DseGraph.query_from_traversal(bytecode, graph_options.graph_protocol, context) + + return query, ep + + @staticmethod + def _handle_query_results(result_set, gremlin_future): + try: + gremlin_future.set_result( + RemoteTraversal(DSESessionRemoteGraphConnection._traversers_generator(result_set), TraversalSideEffects()) + ) + except Exception as e: + gremlin_future.set_exception(e) + + @staticmethod + def _handle_query_error(response, gremlin_future): + gremlin_future.set_exception(response) + + def submit(self, bytecode): + # the only reason I don't use submitAsync here + # is to avoid an unuseful future wrap + query, ep = self._prepare_query(bytecode) + + traversers = self.session.execute_graph(query, execution_profile=ep) + return RemoteTraversal(self._traversers_generator(traversers), TraversalSideEffects()) + + def submitAsync(self, bytecode): + query, ep = self._prepare_query(bytecode) + + # to be compatible with gremlinpython, we need to return a concurrent.futures.Future + gremlin_future = Future() + response_future = self.session.execute_graph_async(query, execution_profile=ep) + response_future.add_callback(self._handle_query_results, gremlin_future) + response_future.add_errback(self._handle_query_error, gremlin_future) + + return gremlin_future + + def __str__(self): + return "".format(self.graph_name) + + __repr__ = __str__ + + + class DseGraph(object): + """ + Dse Graph utility class for GraphTraversal construction and execution. + """ + + DSE_GRAPH_QUERY_LANGUAGE = 'bytecode-json' + """ + Graph query language, Default is 'bytecode-json' (GraphSON). + """ + + DSE_GRAPH_QUERY_PROTOCOL = GraphProtocol.GRAPHSON_2_0 + """ + Graph query language, Default is GraphProtocol.GRAPHSON_2_0. + """ + + @staticmethod + def query_from_traversal(traversal, graph_protocol=DSE_GRAPH_QUERY_PROTOCOL, context=None): + """ + From a GraphTraversal, return a query string based on the language specified in `DseGraph.DSE_GRAPH_QUERY_LANGUAGE`. 
+ + :param traversal: The GraphTraversal object + :param graph_protocol: The graph protocol. Default is `DseGraph.DSE_GRAPH_QUERY_PROTOCOL`. + :param context: The dict of the serialization context, needed for GraphSON3 (tuple, udt). + e.g: {'cluster': cluster, 'graph_name': name} + """ + + if isinstance(traversal, GraphTraversal): + for strategy in traversal.traversal_strategies.traversal_strategies: + rc = strategy.remote_connection + if (isinstance(rc, DSESessionRemoteGraphConnection) and + rc.session or rc.graph_name or rc.execution_profile): + log.warning("GraphTraversal session, graph_name and execution_profile are " + "only taken into account when executed with TinkerPop.") + + return _query_from_traversal(traversal, graph_protocol, context) + + @staticmethod + def traversal_source(session=None, graph_name=None, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT, + traversal_class=None): + """ + Returns a TinkerPop GraphTraversalSource binded to the session and graph_name if provided. + + :param session: (Optional) A DSE session + :param graph_name: (Optional) DSE Graph name + :param execution_profile: (Optional) Execution profile for traversal queries. Default is set to `EXEC_PROFILE_GRAPH_DEFAULT`. + :param traversal_class: (Optional) The GraphTraversalSource class to use (DSL). + + .. code-block:: python + + from cassandra.cluster import Cluster + from cassandra.datastax.graph.fluent import DseGraph + + c = Cluster() + session = c.connect() + + g = DseGraph.traversal_source(session, 'my_graph') + print(g.V().valueMap().toList()) + + """ + + graph = Graph() + traversal_source = graph.traversal(traversal_class) + + if session: + traversal_source = traversal_source.withRemote( + DSESessionRemoteGraphConnection(session, graph_name, execution_profile)) + + return traversal_source + + @staticmethod + def create_execution_profile(graph_name, graph_protocol=DSE_GRAPH_QUERY_PROTOCOL, **kwargs): + """ + Creates an ExecutionProfile for GraphTraversal execution. You need to register that execution profile to the + cluster by using `cluster.add_execution_profile`. + + :param graph_name: The graph name + :param graph_protocol: (Optional) The graph protocol, default is `DSE_GRAPH_QUERY_PROTOCOL`. + """ + + if graph_protocol == GraphProtocol.GRAPHSON_2_0: + row_factory = dse_graphson2_traversal_row_factory + elif graph_protocol == GraphProtocol.GRAPHSON_3_0: + row_factory = dse_graphson3_traversal_row_factory + else: + raise ValueError('Unknown graph protocol: {}'.format(graph_protocol)) + + ep = GraphExecutionProfile(row_factory=row_factory, + graph_options=GraphOptions(graph_name=graph_name, + graph_language=DseGraph.DSE_GRAPH_QUERY_LANGUAGE, + graph_protocol=graph_protocol), + **kwargs) + return ep + + @staticmethod + def batch(*args, **kwargs): + """ + Returns the :class:`cassandra.datastax.graph.fluent.query.TraversalBatch` object allowing to + execute multiple traversals in the same transaction. + """ + return _DefaultTraversalBatch(*args, **kwargs) diff --git a/cassandra/datastax/graph/fluent/_predicates.py b/cassandra/datastax/graph/fluent/_predicates.py new file mode 100644 index 0000000000..1c7825455a --- /dev/null +++ b/cassandra/datastax/graph/fluent/_predicates.py @@ -0,0 +1,204 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +from gremlin_python.process.traversal import P + +from cassandra.util import Distance + +__all__ = ['GeoP', 'TextDistanceP', 'Search', 'GeoUnit', 'Geo', 'CqlCollection'] + + +class GeoP(object): + + def __init__(self, operator, value, other=None): + self.operator = operator + self.value = value + self.other = other + + @staticmethod + def inside(*args, **kwargs): + return GeoP("inside", *args, **kwargs) + + def __eq__(self, other): + return isinstance(other, + self.__class__) and self.operator == other.operator and self.value == other.value and self.other == other.other + + def __repr__(self): + return self.operator + "(" + str(self.value) + ")" if self.other is None else self.operator + "(" + str( + self.value) + "," + str(self.other) + ")" + + +class TextDistanceP(object): + + def __init__(self, operator, value, distance): + self.operator = operator + self.value = value + self.distance = distance + + @staticmethod + def fuzzy(*args): + return TextDistanceP("fuzzy", *args) + + @staticmethod + def token_fuzzy(*args): + return TextDistanceP("tokenFuzzy", *args) + + @staticmethod + def phrase(*args): + return TextDistanceP("phrase", *args) + + def __eq__(self, other): + return isinstance(other, + self.__class__) and self.operator == other.operator and self.value == other.value and self.distance == other.distance + + def __repr__(self): + return self.operator + "(" + str(self.value) + "," + str(self.distance) + ")" + + +class Search(object): + + @staticmethod + def token(value): + """ + Search any instance of a certain token within the text property targeted. + :param value: the value to look for. + """ + return P('token', value) + + @staticmethod + def token_prefix(value): + """ + Search any instance of a certain token prefix withing the text property targeted. + :param value: the value to look for. + """ + return P('tokenPrefix', value) + + @staticmethod + def token_regex(value): + """ + Search any instance of the provided regular expression for the targeted property. + :param value: the value to look for. + """ + return P('tokenRegex', value) + + @staticmethod + def prefix(value): + """ + Search for a specific prefix at the beginning of the text property targeted. + :param value: the value to look for. + """ + return P('prefix', value) + + @staticmethod + def regex(value): + """ + Search for this regular expression inside the text property targeted. + :param value: the value to look for. + """ + return P('regex', value) + + @staticmethod + def fuzzy(value, distance): + """ + Search for a fuzzy string inside the text property targeted. + :param value: the value to look for. + :param distance: The distance for the fuzzy search. ie. 1, to allow a one-letter misspellings. + """ + return TextDistanceP.fuzzy(value, distance) + + @staticmethod + def token_fuzzy(value, distance): + """ + Search for a token fuzzy inside the text property targeted. + :param value: the value to look for. 
+ :param distance: The distance for the token fuzzy search. ie. 1, to allow a one-letter misspellings. + """ + return TextDistanceP.token_fuzzy(value, distance) + + @staticmethod + def phrase(value, proximity): + """ + Search for a phrase inside the text property targeted. + :param value: the value to look for. + :param proximity: The proximity for the phrase search. ie. phrase('David Felcey', 2).. to find 'David Felcey' with up to two middle names. + """ + return TextDistanceP.phrase(value, proximity) + + +class CqlCollection(object): + + @staticmethod + def contains(value): + """ + Search for a value inside a cql list/set column. + :param value: the value to look for. + """ + return P('contains', value) + + @staticmethod + def contains_value(value): + """ + Search for a map value. + :param value: the value to look for. + """ + return P('containsValue', value) + + @staticmethod + def contains_key(value): + """ + Search for a map key. + :param value: the value to look for. + """ + return P('containsKey', value) + + @staticmethod + def entry_eq(value): + """ + Search for a map entry. + :param value: the value to look for. + """ + return P('entryEq', value) + + +class GeoUnit(object): + _EARTH_MEAN_RADIUS_KM = 6371.0087714 + _DEGREES_TO_RADIANS = math.pi / 180 + _DEG_TO_KM = _DEGREES_TO_RADIANS * _EARTH_MEAN_RADIUS_KM + _KM_TO_DEG = 1 / _DEG_TO_KM + _MILES_TO_KM = 1.609344001 + + MILES = _MILES_TO_KM * _KM_TO_DEG + KILOMETERS = _KM_TO_DEG + METERS = _KM_TO_DEG / 1000.0 + DEGREES = 1 + + +class Geo(object): + + @staticmethod + def inside(value, units=GeoUnit.DEGREES): + """ + Search any instance of geometry inside the Distance targeted. + :param value: A Distance to look for. + :param units: The units for ``value``. See GeoUnit enum. (Can also + provide an integer to use as a multiplier to convert ``value`` to + degrees.) + """ + return GeoP.inside( + value=Distance(x=value.x, y=value.y, radius=value.radius * units) + ) diff --git a/cassandra/datastax/graph/fluent/_query.py b/cassandra/datastax/graph/fluent/_query.py new file mode 100644 index 0000000000..c476653541 --- /dev/null +++ b/cassandra/datastax/graph/fluent/_query.py @@ -0,0 +1,230 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
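The predicates defined above plug into ordinary gremlin ``has()`` steps. A hedged usage sketch (graph, labels and property names are invented, and a search-indexed graph is assumed):

.. code-block:: python

    from cassandra.cluster import Cluster
    from cassandra.datastax.graph.fluent import DseGraph
    from cassandra.datastax.graph.fluent.predicates import Search, Geo, GeoUnit
    from cassandra.util import Distance

    session = Cluster().connect()
    g = DseGraph.traversal_source(session, 'demo_graph')

    # Text predicates against indexed text properties.
    prefixed = g.V().hasLabel('recipe').has('instructions', Search.token_prefix('bake')).toList()
    fuzzy = g.V().has('name', Search.fuzzy('David', 1)).toList()

    # Geo predicate: vertices whose 'location' lies within 10 km of (2.35, 48.85).
    nearby = g.V().has('location',
                       Geo.inside(Distance(x=2.35, y=48.85, radius=10), GeoUnit.KILOMETERS)).toList()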
+ +import logging + +from cassandra.graph import SimpleGraphStatement, GraphProtocol +from cassandra.cluster import EXEC_PROFILE_GRAPH_DEFAULT + +from gremlin_python.process.graph_traversal import GraphTraversal +from gremlin_python.structure.io.graphsonV2d0 import GraphSONWriter as GraphSONWriterV2 +from gremlin_python.structure.io.graphsonV3d0 import GraphSONWriter as GraphSONWriterV3 + +from cassandra.datastax.graph.fluent.serializers import GremlinUserTypeIO, \ + dse_graphson2_serializers, dse_graphson3_serializers + +log = logging.getLogger(__name__) + + +__all__ = ['TraversalBatch', '_query_from_traversal', '_DefaultTraversalBatch'] + + +class _GremlinGraphSONWriterAdapter(object): + + def __init__(self, context, **kwargs): + super(_GremlinGraphSONWriterAdapter, self).__init__(**kwargs) + self.context = context + self.user_types = None + + def serialize(self, value, _): + return self.toDict(value) + + def get_serializer(self, value): + serializer = None + try: + serializer = self.serializers[type(value)] + except KeyError: + for key, ser in self.serializers.items(): + if isinstance(value, key): + serializer = ser + + if self.context: + # Check if UDT + if self.user_types is None: + try: + user_types = self.context['cluster']._user_types[self.context['graph_name']] + self.user_types = dict(map(reversed, user_types.items())) + except KeyError: + self.user_types = {} + + # Custom detection to map a namedtuple to udt + if (tuple in self.serializers and serializer is self.serializers[tuple] and hasattr(value, '_fields') or + (not serializer and type(value) in self.user_types)): + serializer = GremlinUserTypeIO + + if serializer: + try: + # A serializer can have specialized serializers (e.g for Int32 and Int64, so value dependant) + serializer = serializer.get_specialized_serializer(value) + except AttributeError: + pass + + return serializer + + def toDict(self, obj): + serializer = self.get_serializer(obj) + return serializer.dictify(obj, self) if serializer else obj + + def definition(self, value): + serializer = self.get_serializer(value) + return serializer.definition(value, self) + + +class GremlinGraphSON2Writer(_GremlinGraphSONWriterAdapter, GraphSONWriterV2): + pass + + +class GremlinGraphSON3Writer(_GremlinGraphSONWriterAdapter, GraphSONWriterV3): + pass + + +graphson2_writer = GremlinGraphSON2Writer +graphson3_writer = GremlinGraphSON3Writer + + +def _query_from_traversal(traversal, graph_protocol, context=None): + """ + From a GraphTraversal, return a query string. + + :param traversal: The GraphTraversal object + :param graphson_protocol: The graph protocol to determine the output format. + """ + if graph_protocol == GraphProtocol.GRAPHSON_2_0: + graphson_writer = graphson2_writer(context, serializer_map=dse_graphson2_serializers) + elif graph_protocol == GraphProtocol.GRAPHSON_3_0: + if context is None: + raise ValueError('Missing context for GraphSON3 serialization requires.') + graphson_writer = graphson3_writer(context, serializer_map=dse_graphson3_serializers) + else: + raise ValueError('Unknown graph protocol: {}'.format(graph_protocol)) + + try: + query = graphson_writer.writeObject(traversal) + except Exception: + log.exception("Error preparing graphson traversal query:") + raise + + return query + + +class TraversalBatch(object): + """ + A `TraversalBatch` is used to execute multiple graph traversals in a + single transaction. If any traversal in the batch fails, the entire + batch will fail to apply. 
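A hedged sketch of how a batch is typically built and executed through ``DseGraph.batch()`` defined earlier in this patch (the graph name, profile key and vertex data are invented):

.. code-block:: python

    from cassandra.cluster import Cluster
    from cassandra.datastax.graph.fluent import DseGraph

    ep = DseGraph.create_execution_profile('demo_graph')
    cluster = Cluster(execution_profiles={'core_graph': ep})
    session = cluster.connect()

    g = DseGraph.traversal_source(session, 'demo_graph', execution_profile='core_graph')

    batch = DseGraph.batch(session, 'core_graph')
    batch.add(g.addV('person').property('name', 'alice'))
    batch.add(g.addV('person').property('name', 'bob'))
    batch.execute()  # both vertices are applied in one transaction, or neither is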
+ + If a TraversalBatch is bounded to a DSE session, it can be executed using + `traversal_batch.execute()`. + """ + + _session = None + _execution_profile = None + + def __init__(self, session=None, execution_profile=None): + """ + :param session: (Optional) A DSE session + :param execution_profile: (Optional) The execution profile to use for the batch execution + """ + self._session = session + self._execution_profile = execution_profile + + def add(self, traversal): + """ + Add a traversal to the batch. + + :param traversal: A gremlin GraphTraversal + """ + raise NotImplementedError() + + def add_all(self, traversals): + """ + Adds a sequence of traversals to the batch. + + :param traversals: A sequence of gremlin GraphTraversal + """ + raise NotImplementedError() + + def execute(self): + """ + Execute the traversal batch if bounded to a `DSE Session`. + """ + raise NotImplementedError() + + def as_graph_statement(self, graph_protocol=GraphProtocol.GRAPHSON_2_0): + """ + Return the traversal batch as GraphStatement. + + :param graph_protocol: The graph protocol for the GraphSONWriter. Default is GraphProtocol.GRAPHSON_2_0. + """ + raise NotImplementedError() + + def clear(self): + """ + Clear a traversal batch for reuse. + """ + raise NotImplementedError() + + def __len__(self): + raise NotImplementedError() + + def __str__(self): + return u''.format(len(self)) + __repr__ = __str__ + + +class _DefaultTraversalBatch(TraversalBatch): + + _traversals = None + + def __init__(self, *args, **kwargs): + super(_DefaultTraversalBatch, self).__init__(*args, **kwargs) + self._traversals = [] + + def add(self, traversal): + if not isinstance(traversal, GraphTraversal): + raise ValueError('traversal should be a gremlin GraphTraversal') + + self._traversals.append(traversal) + return self + + def add_all(self, traversals): + for traversal in traversals: + self.add(traversal) + + def as_graph_statement(self, graph_protocol=GraphProtocol.GRAPHSON_2_0, context=None): + statements = [_query_from_traversal(t, graph_protocol, context) for t in self._traversals] + query = u"[{0}]".format(','.join(statements)) + return SimpleGraphStatement(query) + + def execute(self): + if self._session is None: + raise ValueError('A DSE Session must be provided to execute the traversal batch.') + + execution_profile = self._execution_profile if self._execution_profile else EXEC_PROFILE_GRAPH_DEFAULT + graph_options = self._session.get_execution_profile(execution_profile).graph_options + context = { + 'cluster': self._session.cluster, + 'graph_name': graph_options.graph_name + } + statement = self.as_graph_statement(graph_options.graph_protocol, context=context) \ + if graph_options.graph_protocol else self.as_graph_statement(context=context) + return self._session.execute_graph(statement, execution_profile=execution_profile) + + def clear(self): + del self._traversals[:] + + def __len__(self): + return len(self._traversals) diff --git a/cassandra/datastax/graph/fluent/_serializers.py b/cassandra/datastax/graph/fluent/_serializers.py new file mode 100644 index 0000000000..b6c705771f --- /dev/null +++ b/cassandra/datastax/graph/fluent/_serializers.py @@ -0,0 +1,262 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict + +from gremlin_python.structure.io.graphsonV2d0 import ( + GraphSONReader as GraphSONReaderV2, + GraphSONUtil as GraphSONUtil, # no difference between v2 and v3 + VertexDeserializer as VertexDeserializerV2, + VertexPropertyDeserializer as VertexPropertyDeserializerV2, + PropertyDeserializer as PropertyDeserializerV2, + EdgeDeserializer as EdgeDeserializerV2, + PathDeserializer as PathDeserializerV2 +) + +from gremlin_python.structure.io.graphsonV3d0 import ( + GraphSONReader as GraphSONReaderV3, + VertexDeserializer as VertexDeserializerV3, + VertexPropertyDeserializer as VertexPropertyDeserializerV3, + PropertyDeserializer as PropertyDeserializerV3, + EdgeDeserializer as EdgeDeserializerV3, + PathDeserializer as PathDeserializerV3 +) + +try: + from gremlin_python.structure.io.graphsonV2d0 import ( + TraversalMetricsDeserializer as TraversalMetricsDeserializerV2, + MetricsDeserializer as MetricsDeserializerV2 + ) + from gremlin_python.structure.io.graphsonV3d0 import ( + TraversalMetricsDeserializer as TraversalMetricsDeserializerV3, + MetricsDeserializer as MetricsDeserializerV3 + ) +except ImportError: + TraversalMetricsDeserializerV2 = MetricsDeserializerV2 = None + TraversalMetricsDeserializerV3 = MetricsDeserializerV3 = None + +from cassandra.graph import ( + GraphSON2Serializer, + GraphSON2Deserializer, + GraphSON3Serializer, + GraphSON3Deserializer +) +from cassandra.graph.graphson import UserTypeIO, TypeWrapperTypeIO +from cassandra.datastax.graph.fluent.predicates import GeoP, TextDistanceP +from cassandra.util import Distance + + +__all__ = ['GremlinGraphSONReader', 'GeoPSerializer', 'TextDistancePSerializer', + 'DistanceIO', 'gremlin_deserializers', 'deserializers', 'serializers', + 'GremlinGraphSONReaderV2', 'GremlinGraphSONReaderV3', 'dse_graphson2_serializers', + 'dse_graphson2_deserializers', 'dse_graphson3_serializers', 'dse_graphson3_deserializers', + 'gremlin_graphson2_deserializers', 'gremlin_graphson3_deserializers', 'GremlinUserTypeIO'] + + +class _GremlinGraphSONTypeSerializer(object): + TYPE_KEY = "@type" + VALUE_KEY = "@value" + serializer = None + + def __init__(self, serializer): + self.serializer = serializer + + def dictify(self, v, writer): + value = self.serializer.serialize(v, writer) + if self.serializer is TypeWrapperTypeIO: + graphson_base_type = v.type_io.graphson_base_type + graphson_type = v.type_io.graphson_type + else: + graphson_base_type = self.serializer.graphson_base_type + graphson_type = self.serializer.graphson_type + + if graphson_base_type is None: + out = value + else: + out = {self.TYPE_KEY: graphson_type} + if value is not None: + out[self.VALUE_KEY] = value + + return out + + def definition(self, value, writer=None): + return self.serializer.definition(value, writer) + + def get_specialized_serializer(self, value): + ser = self.serializer.get_specialized_serializer(value) + if ser is not self.serializer: + return 
_GremlinGraphSONTypeSerializer(ser) + return self + + +class _GremlinGraphSONTypeDeserializer(object): + + deserializer = None + + def __init__(self, deserializer): + self.deserializer = deserializer + + def objectify(self, v, reader): + return self.deserializer.deserialize(v, reader) + + +def _make_gremlin_graphson2_deserializer(graphson_type): + return _GremlinGraphSONTypeDeserializer( + GraphSON2Deserializer.get_deserializer(graphson_type.graphson_type) + ) + + +def _make_gremlin_graphson3_deserializer(graphson_type): + return _GremlinGraphSONTypeDeserializer( + GraphSON3Deserializer.get_deserializer(graphson_type.graphson_type) + ) + + +class _GremlinGraphSONReader(object): + """Gremlin GraphSONReader Adapter, required to use gremlin types""" + + context = None + + def __init__(self, context, deserializer_map=None): + self.context = context + super(_GremlinGraphSONReader, self).__init__(deserializer_map) + + def deserialize(self, obj): + return self.toObject(obj) + + +class GremlinGraphSONReaderV2(_GremlinGraphSONReader, GraphSONReaderV2): + pass + +# TODO remove next major +GremlinGraphSONReader = GremlinGraphSONReaderV2 + +class GremlinGraphSONReaderV3(_GremlinGraphSONReader, GraphSONReaderV3): + pass + + +class GeoPSerializer(object): + @classmethod + def dictify(cls, p, writer): + out = { + "predicateType": "Geo", + "predicate": p.operator, + "value": [writer.toDict(p.value), writer.toDict(p.other)] if p.other is not None else writer.toDict(p.value) + } + return GraphSONUtil.typedValue("P", out, prefix='dse') + + +class TextDistancePSerializer(object): + @classmethod + def dictify(cls, p, writer): + out = { + "predicate": p.operator, + "value": { + 'query': writer.toDict(p.value), + 'distance': writer.toDict(p.distance) + } + } + return GraphSONUtil.typedValue("P", out) + + +class DistanceIO(object): + @classmethod + def dictify(cls, v, _): + return GraphSONUtil.typedValue('Distance', str(v), prefix='dse') + + +GremlinUserTypeIO = _GremlinGraphSONTypeSerializer(UserTypeIO) + +# GraphSON2 +dse_graphson2_serializers = OrderedDict([ + (t, _GremlinGraphSONTypeSerializer(s)) + for t, s in GraphSON2Serializer.get_type_definitions().items() +]) + +dse_graphson2_serializers.update(OrderedDict([ + (Distance, DistanceIO), + (GeoP, GeoPSerializer), + (TextDistanceP, TextDistancePSerializer) +])) + +# TODO remove next major, this is just in case someone was using it +serializers = dse_graphson2_serializers + +dse_graphson2_deserializers = { + k: _make_gremlin_graphson2_deserializer(v) + for k, v in GraphSON2Deserializer.get_type_definitions().items() +} + +dse_graphson2_deserializers.update({ + "dse:Distance": DistanceIO, +}) + +# TODO remove next major, this is just in case someone was using it +deserializers = dse_graphson2_deserializers + +gremlin_graphson2_deserializers = dse_graphson2_deserializers.copy() +gremlin_graphson2_deserializers.update({ + 'g:Vertex': VertexDeserializerV2, + 'g:VertexProperty': VertexPropertyDeserializerV2, + 'g:Edge': EdgeDeserializerV2, + 'g:Property': PropertyDeserializerV2, + 'g:Path': PathDeserializerV2 +}) + +if TraversalMetricsDeserializerV2: + gremlin_graphson2_deserializers.update({ + 'g:TraversalMetrics': TraversalMetricsDeserializerV2, + 'g:lMetrics': MetricsDeserializerV2 + }) + +# TODO remove next major, this is just in case someone was using it +gremlin_deserializers = gremlin_graphson2_deserializers + +# GraphSON3 +dse_graphson3_serializers = OrderedDict([ + (t, _GremlinGraphSONTypeSerializer(s)) + for t, s in 
GraphSON3Serializer.get_type_definitions().items() +]) + +dse_graphson3_serializers.update(OrderedDict([ + (Distance, DistanceIO), + (GeoP, GeoPSerializer), + (TextDistanceP, TextDistancePSerializer) +])) + +dse_graphson3_deserializers = { + k: _make_gremlin_graphson3_deserializer(v) + for k, v in GraphSON3Deserializer.get_type_definitions().items() +} + +dse_graphson3_deserializers.update({ + "dse:Distance": DistanceIO +}) + +gremlin_graphson3_deserializers = dse_graphson3_deserializers.copy() +gremlin_graphson3_deserializers.update({ + 'g:Vertex': VertexDeserializerV3, + 'g:VertexProperty': VertexPropertyDeserializerV3, + 'g:Edge': EdgeDeserializerV3, + 'g:Property': PropertyDeserializerV3, + 'g:Path': PathDeserializerV3 +}) + +if TraversalMetricsDeserializerV3: + gremlin_graphson3_deserializers.update({ + 'g:TraversalMetrics': TraversalMetricsDeserializerV3, + 'g:Metrics': MetricsDeserializerV3 + }) diff --git a/cassandra/datastax/graph/fluent/predicates.py b/cassandra/datastax/graph/fluent/predicates.py new file mode 100644 index 0000000000..8dca8b84ce --- /dev/null +++ b/cassandra/datastax/graph/fluent/predicates.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import gremlin_python + from cassandra.datastax.graph.fluent._predicates import * +except ImportError: + # gremlinpython is not installed. + pass diff --git a/cassandra/datastax/graph/fluent/query.py b/cassandra/datastax/graph/fluent/query.py new file mode 100644 index 0000000000..f599f2c979 --- /dev/null +++ b/cassandra/datastax/graph/fluent/query.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import gremlin_python + from cassandra.datastax.graph.fluent._query import * +except ImportError: + # gremlinpython is not installed. 
+ pass diff --git a/cassandra/datastax/graph/fluent/serializers.py b/cassandra/datastax/graph/fluent/serializers.py new file mode 100644 index 0000000000..3c175f92d4 --- /dev/null +++ b/cassandra/datastax/graph/fluent/serializers.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import gremlin_python + from cassandra.datastax.graph.fluent._serializers import * +except ImportError: + # gremlinpython is not installed. + pass diff --git a/cassandra/datastax/graph/graphson.py b/cassandra/datastax/graph/graphson.py new file mode 100644 index 0000000000..7b284c4c26 --- /dev/null +++ b/cassandra/datastax/graph/graphson.py @@ -0,0 +1,1134 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
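The three thin modules above (predicates, query, serializers) all use the same optional-dependency guard, so the driver imports cleanly without gremlinpython and the fluent API simply becomes unavailable. A hedged sketch of how calling code can probe for it:

.. code-block:: python

    try:
        from cassandra.datastax.graph.fluent import DseGraph
        HAVE_FLUENT = True
    except ImportError:
        # gremlinpython is not installed, so DseGraph is not defined.
        DseGraph = None
        HAVE_FLUENT = False

    if HAVE_FLUENT:
        print('fluent graph API available')
    else:
        print('install gremlinpython to enable the fluent graph API')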
+ +import datetime +import base64 +import uuid +import re +import json +from decimal import Decimal +from collections import OrderedDict +import logging +import itertools +from functools import partial + +import ipaddress + + +from cassandra.cqltypes import cql_types_from_string +from cassandra.metadata import UserType +from cassandra.util import Polygon, Point, LineString, Duration +from cassandra.datastax.graph.types import Vertex, VertexProperty, Edge, Path, T + +__all__ = ['GraphSON1Serializer', 'GraphSON1Deserializer', 'GraphSON1TypeDeserializer', + 'GraphSON2Serializer', 'GraphSON2Deserializer', 'GraphSON2Reader', + 'GraphSON3Serializer', 'GraphSON3Deserializer', 'GraphSON3Reader', + 'to_bigint', 'to_int', 'to_double', 'to_float', 'to_smallint', + 'BooleanTypeIO', 'Int16TypeIO', 'Int32TypeIO', 'DoubleTypeIO', + 'FloatTypeIO', 'UUIDTypeIO', 'BigDecimalTypeIO', 'DurationTypeIO', 'InetTypeIO', + 'InstantTypeIO', 'LocalDateTypeIO', 'LocalTimeTypeIO', 'Int64TypeIO', 'BigIntegerTypeIO', + 'LocalDateTypeIO', 'PolygonTypeIO', 'PointTypeIO', 'LineStringTypeIO', 'BlobTypeIO', + 'GraphSON3Serializer', 'GraphSON3Deserializer', 'UserTypeIO', 'TypeWrapperTypeIO'] + +""" +Supported types: + +DSE Graph GraphSON 2.0 GraphSON 3.0 | Python Driver +------------ | -------------- | -------------- | ------------ +text | string | string | str +boolean | | | bool +bigint | g:Int64 | g:Int64 | long +int | g:Int32 | g:Int32 | int +double | g:Double | g:Double | float +float | g:Float | g:Float | float +uuid | g:UUID | g:UUID | UUID +bigdecimal | gx:BigDecimal | gx:BigDecimal | Decimal +duration | gx:Duration | N/A | timedelta (Classic graph only) +DSE Duration | N/A | dse:Duration | Duration (Core graph only) +inet | gx:InetAddress | gx:InetAddress | str (unicode), IPV4Address/IPV6Address (PY3) +timestamp | gx:Instant | gx:Instant | datetime.datetime +date | gx:LocalDate | gx:LocalDate | datetime.date +time | gx:LocalTime | gx:LocalTime | datetime.time +smallint | gx:Int16 | gx:Int16 | int +varint | gx:BigInteger | gx:BigInteger | long +date | gx:LocalDate | gx:LocalDate | Date +polygon | dse:Polygon | dse:Polygon | Polygon +point | dse:Point | dse:Point | Point +linestring | dse:Linestring | dse:LineString | LineString +blob | dse:Blob | dse:Blob | bytearray, buffer (PY2), memoryview (PY3), bytes (PY3) +blob | gx:ByteBuffer | gx:ByteBuffer | bytearray, buffer (PY2), memoryview (PY3), bytes (PY3) +list | N/A | g:List | list (Core graph only) +map | N/A | g:Map | dict (Core graph only) +set | N/A | g:Set | set or list (Core graph only) + Can return a list due to numerical values returned by Java +tuple | N/A | dse:Tuple | tuple (Core graph only) +udt | N/A | dse:UDT | class or namedtuple (Core graph only) +""" + +MAX_INT32 = 2 ** 32 - 1 +MIN_INT32 = -2 ** 31 + +log = logging.getLogger(__name__) + + +class _GraphSONTypeType(type): + """GraphSONType metaclass, required to create a class property.""" + + @property + def graphson_type(cls): + return "{0}:{1}".format(cls.prefix, cls.graphson_base_type) + + +class GraphSONTypeIO(object, metaclass=_GraphSONTypeType): + """Represent a serializable GraphSON type""" + + prefix = 'g' + graphson_base_type = None + cql_type = None + + @classmethod + def definition(cls, value, writer=None): + return {'cqlType': cls.cql_type} + + @classmethod + def serialize(cls, value, writer=None): + return str(value) + + @classmethod + def deserialize(cls, value, reader=None): + return value + + @classmethod + def get_specialized_serializer(cls, value): + return cls + + +class 
TextTypeIO(GraphSONTypeIO): + cql_type = 'text' + + +class BooleanTypeIO(GraphSONTypeIO): + graphson_base_type = None + cql_type = 'boolean' + + @classmethod + def serialize(cls, value, writer=None): + return bool(value) + + +class IntegerTypeIO(GraphSONTypeIO): + + @classmethod + def serialize(cls, value, writer=None): + return value + + @classmethod + def get_specialized_serializer(cls, value): + if type(value) is int and (value > MAX_INT32 or value < MIN_INT32): + return Int64TypeIO + + return Int32TypeIO + + +class Int16TypeIO(IntegerTypeIO): + prefix = 'gx' + graphson_base_type = 'Int16' + cql_type = 'smallint' + + +class Int32TypeIO(IntegerTypeIO): + graphson_base_type = 'Int32' + cql_type = 'int' + + +class Int64TypeIO(IntegerTypeIO): + graphson_base_type = 'Int64' + cql_type = 'bigint' + + @classmethod + def deserialize(cls, value, reader=None): + return value + + +class FloatTypeIO(GraphSONTypeIO): + graphson_base_type = 'Float' + cql_type = 'float' + + @classmethod + def serialize(cls, value, writer=None): + return value + + @classmethod + def deserialize(cls, value, reader=None): + return float(value) + + +class DoubleTypeIO(FloatTypeIO): + graphson_base_type = 'Double' + cql_type = 'double' + + +class BigIntegerTypeIO(IntegerTypeIO): + prefix = 'gx' + graphson_base_type = 'BigInteger' + + +class LocalDateTypeIO(GraphSONTypeIO): + FORMAT = '%Y-%m-%d' + + prefix = 'gx' + graphson_base_type = 'LocalDate' + cql_type = 'date' + + @classmethod + def serialize(cls, value, writer=None): + return value.isoformat() + + @classmethod + def deserialize(cls, value, reader=None): + try: + return datetime.datetime.strptime(value, cls.FORMAT).date() + except ValueError: + # negative date + return value + + +class InstantTypeIO(GraphSONTypeIO): + prefix = 'gx' + graphson_base_type = 'Instant' + cql_type = 'timestamp' + + @classmethod + def serialize(cls, value, writer=None): + if isinstance(value, datetime.datetime): + value = datetime.datetime(*value.utctimetuple()[:6]).replace(microsecond=value.microsecond) + else: + value = datetime.datetime.combine(value, datetime.datetime.min.time()) + + return "{0}Z".format(value.isoformat()) + + @classmethod + def deserialize(cls, value, reader=None): + try: + d = datetime.datetime.strptime(value, '%Y-%m-%dT%H:%M:%S.%fZ') + except ValueError: + d = datetime.datetime.strptime(value, '%Y-%m-%dT%H:%M:%SZ') + return d + + +class LocalTimeTypeIO(GraphSONTypeIO): + FORMATS = [ + '%H:%M', + '%H:%M:%S', + '%H:%M:%S.%f' + ] + + prefix = 'gx' + graphson_base_type = 'LocalTime' + cql_type = 'time' + + @classmethod + def serialize(cls, value, writer=None): + return value.strftime(cls.FORMATS[2]) + + @classmethod + def deserialize(cls, value, reader=None): + dt = None + for f in cls.FORMATS: + try: + dt = datetime.datetime.strptime(value, f) + break + except ValueError: + continue + + if dt is None: + raise ValueError('Unable to decode LocalTime: {0}'.format(value)) + + return dt.time() + + +class BlobTypeIO(GraphSONTypeIO): + prefix = 'dse' + graphson_base_type = 'Blob' + cql_type = 'blob' + + @classmethod + def serialize(cls, value, writer=None): + value = base64.b64encode(value) + value = value.decode('utf-8') + return value + + @classmethod + def deserialize(cls, value, reader=None): + return bytearray(base64.b64decode(value)) + + +class ByteBufferTypeIO(BlobTypeIO): + prefix = 'gx' + graphson_base_type = 'ByteBuffer' + + +class UUIDTypeIO(GraphSONTypeIO): + graphson_base_type = 'UUID' + cql_type = 'uuid' + + @classmethod + def deserialize(cls, value, 
reader=None): + return uuid.UUID(value) + + +class BigDecimalTypeIO(GraphSONTypeIO): + prefix = 'gx' + graphson_base_type = 'BigDecimal' + cql_type = 'bigdecimal' + + @classmethod + def deserialize(cls, value, reader=None): + return Decimal(value) + + +class DurationTypeIO(GraphSONTypeIO): + prefix = 'gx' + graphson_base_type = 'Duration' + cql_type = 'duration' + + _duration_regex = re.compile(r""" + ^P((?P\d+)D)? + T((?P\d+)H)? + ((?P\d+)M)? + ((?P[0-9.]+)S)?$ + """, re.VERBOSE) + _duration_format = "P{days}DT{hours}H{minutes}M{seconds}S" + + _seconds_in_minute = 60 + _seconds_in_hour = 60 * _seconds_in_minute + _seconds_in_day = 24 * _seconds_in_hour + + @classmethod + def serialize(cls, value, writer=None): + total_seconds = int(value.total_seconds()) + days, total_seconds = divmod(total_seconds, cls._seconds_in_day) + hours, total_seconds = divmod(total_seconds, cls._seconds_in_hour) + minutes, total_seconds = divmod(total_seconds, cls._seconds_in_minute) + total_seconds += value.microseconds / 1e6 + + return cls._duration_format.format( + days=int(days), hours=int(hours), minutes=int(minutes), seconds=total_seconds + ) + + @classmethod + def deserialize(cls, value, reader=None): + duration = cls._duration_regex.match(value) + if duration is None: + raise ValueError('Invalid duration: {0}'.format(value)) + + duration = {k: float(v) if v is not None else 0 + for k, v in duration.groupdict().items()} + return datetime.timedelta(days=duration['days'], hours=duration['hours'], + minutes=duration['minutes'], seconds=duration['seconds']) + + +class DseDurationTypeIO(GraphSONTypeIO): + prefix = 'dse' + graphson_base_type = 'Duration' + cql_type = 'duration' + + @classmethod + def serialize(cls, value, writer=None): + return { + 'months': value.months, + 'days': value.days, + 'nanos': value.nanoseconds + } + + @classmethod + def deserialize(cls, value, reader=None): + return Duration( + reader.deserialize(value['months']), + reader.deserialize(value['days']), + reader.deserialize(value['nanos']) + ) + + +class TypeWrapperTypeIO(GraphSONTypeIO): + + @classmethod + def definition(cls, value, writer=None): + return {'cqlType': value.type_io.cql_type} + + @classmethod + def serialize(cls, value, writer=None): + return value.type_io.serialize(value.value) + + @classmethod + def deserialize(cls, value, reader=None): + return value.type_io.deserialize(value.value) + + +class PointTypeIO(GraphSONTypeIO): + prefix = 'dse' + graphson_base_type = 'Point' + cql_type = "org.apache.cassandra.db.marshal.PointType" + + @classmethod + def deserialize(cls, value, reader=None): + return Point.from_wkt(value) + + +class LineStringTypeIO(GraphSONTypeIO): + prefix = 'dse' + graphson_base_type = 'LineString' + cql_type = "org.apache.cassandra.db.marshal.LineStringType" + + @classmethod + def deserialize(cls, value, reader=None): + return LineString.from_wkt(value) + + +class PolygonTypeIO(GraphSONTypeIO): + prefix = 'dse' + graphson_base_type = 'Polygon' + cql_type = "org.apache.cassandra.db.marshal.PolygonType" + + @classmethod + def deserialize(cls, value, reader=None): + return Polygon.from_wkt(value) + + +class InetTypeIO(GraphSONTypeIO): + prefix = 'gx' + graphson_base_type = 'InetAddress' + cql_type = 'inet' + + +class VertexTypeIO(GraphSONTypeIO): + graphson_base_type = 'Vertex' + + @classmethod + def deserialize(cls, value, reader=None): + vertex = Vertex(id=reader.deserialize(value["id"]), + label=value["label"] if "label" in value else "vertex", + type='vertex', + properties={}) + # avoid the properties 
processing in Vertex.__init__ + vertex.properties = reader.deserialize(value.get('properties', {})) + return vertex + + +class VertexPropertyTypeIO(GraphSONTypeIO): + graphson_base_type = 'VertexProperty' + + @classmethod + def deserialize(cls, value, reader=None): + return VertexProperty(label=value['label'], + value=reader.deserialize(value["value"]), + properties=reader.deserialize(value.get('properties', {}))) + + +class EdgeTypeIO(GraphSONTypeIO): + graphson_base_type = 'Edge' + + @classmethod + def deserialize(cls, value, reader=None): + in_vertex = Vertex(id=reader.deserialize(value["inV"]), + label=value['inVLabel'], + type='vertex', + properties={}) + out_vertex = Vertex(id=reader.deserialize(value["outV"]), + label=value['outVLabel'], + type='vertex', + properties={}) + return Edge( + id=reader.deserialize(value["id"]), + label=value["label"] if "label" in value else "vertex", + type='edge', + properties=reader.deserialize(value.get("properties", {})), + inV=in_vertex, + inVLabel=value['inVLabel'], + outV=out_vertex, + outVLabel=value['outVLabel'] + ) + + +class PropertyTypeIO(GraphSONTypeIO): + graphson_base_type = 'Property' + + @classmethod + def deserialize(cls, value, reader=None): + return {value["key"]: reader.deserialize(value["value"])} + + +class PathTypeIO(GraphSONTypeIO): + graphson_base_type = 'Path' + + @classmethod + def deserialize(cls, value, reader=None): + labels = [set(label) for label in reader.deserialize(value['labels'])] + objects = [obj for obj in reader.deserialize(value['objects'])] + p = Path(labels, []) + p.objects = objects # avoid the object processing in Path.__init__ + return p + + +class TraversalMetricsTypeIO(GraphSONTypeIO): + graphson_base_type = 'TraversalMetrics' + + @classmethod + def deserialize(cls, value, reader=None): + return reader.deserialize(value) + + +class MetricsTypeIO(GraphSONTypeIO): + graphson_base_type = 'Metrics' + + @classmethod + def deserialize(cls, value, reader=None): + return reader.deserialize(value) + + +class JsonMapTypeIO(GraphSONTypeIO): + """In GraphSON2, dict are simply serialized as json map""" + + @classmethod + def serialize(cls, value, writer=None): + out = {} + for k, v in value.items(): + out[k] = writer.serialize(v, writer) + + return out + + +class MapTypeIO(GraphSONTypeIO): + """In GraphSON3, dict has its own type""" + + graphson_base_type = 'Map' + cql_type = 'map' + + @classmethod + def definition(cls, value, writer=None): + out = OrderedDict([('cqlType', cls.cql_type)]) + out['definition'] = [] + for k, v in value.items(): + # we just need the first pair to write the def + out['definition'].append(writer.definition(k)) + out['definition'].append(writer.definition(v)) + break + return out + + @classmethod + def serialize(cls, value, writer=None): + out = [] + for k, v in value.items(): + out.append(writer.serialize(k, writer)) + out.append(writer.serialize(v, writer)) + + return out + + @classmethod + def deserialize(cls, value, reader=None): + out = {} + a, b = itertools.tee(value) + for key, val in zip( + itertools.islice(a, 0, None, 2), + itertools.islice(b, 1, None, 2) + ): + out[reader.deserialize(key)] = reader.deserialize(val) + return out + + +class ListTypeIO(GraphSONTypeIO): + """In GraphSON3, list has its own type""" + + graphson_base_type = 'List' + cql_type = 'list' + + @classmethod + def definition(cls, value, writer=None): + out = OrderedDict([('cqlType', cls.cql_type)]) + out['definition'] = [] + if value: + out['definition'].append(writer.definition(value[0])) + return out + + 
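A note for readers of MapTypeIO above: GraphSON3 does not encode a g:Map as a JSON object but as a flat list of alternating keys and values, which is why MapTypeIO.deserialize pairs entries back up with itertools.tee/islice. The standalone sketch below is illustrative only (not part of this patch) and simply mirrors that round trip:

    import itertools

    def flatten_map(d):
        # GraphSON3-style g:Map encoding: [k1, v1, k2, v2, ...]
        out = []
        for k, v in d.items():
            out.append(k)
            out.append(v)
        return out

    def unflatten_map(value):
        # Pair the alternating entries back up, mirroring MapTypeIO.deserialize
        out = {}
        a, b = itertools.tee(value)
        for key, val in zip(itertools.islice(a, 0, None, 2),
                            itertools.islice(b, 1, None, 2)):
            out[key] = val
        return out

    assert unflatten_map(flatten_map({'name': 'alice', 'age': 30})) == {'name': 'alice', 'age': 30}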
@classmethod + def serialize(cls, value, writer=None): + return [writer.serialize(v, writer) for v in value] + + @classmethod + def deserialize(cls, value, reader=None): + return [reader.deserialize(obj) for obj in value] + + +class SetTypeIO(GraphSONTypeIO): + """In GraphSON3, set has its own type""" + + graphson_base_type = 'Set' + cql_type = 'set' + + @classmethod + def definition(cls, value, writer=None): + out = OrderedDict([('cqlType', cls.cql_type)]) + out['definition'] = [] + for v in value: + # we only take into account the first value for the definition + out['definition'].append(writer.definition(v)) + break + return out + + @classmethod + def serialize(cls, value, writer=None): + return [writer.serialize(v, writer) for v in value] + + @classmethod + def deserialize(cls, value, reader=None): + lst = [reader.deserialize(obj) for obj in value] + + s = set(lst) + if len(s) != len(lst): + log.warning("Coercing g:Set to list due to numerical values returned by Java. " + "See TINKERPOP-1844 for details.") + return lst + + return s + + +class BulkSetTypeIO(GraphSONTypeIO): + graphson_base_type = "BulkSet" + + @classmethod + def deserialize(cls, value, reader=None): + out = [] + + a, b = itertools.tee(value) + for val, bulk in zip( + itertools.islice(a, 0, None, 2), + itertools.islice(b, 1, None, 2) + ): + val = reader.deserialize(val) + bulk = reader.deserialize(bulk) + for n in range(bulk): + out.append(val) + + return out + + +class TupleTypeIO(GraphSONTypeIO): + prefix = 'dse' + graphson_base_type = 'Tuple' + cql_type = 'tuple' + + @classmethod + def definition(cls, value, writer=None): + out = OrderedDict() + out['cqlType'] = cls.cql_type + serializers = [writer.get_serializer(s) for s in value] + out['definition'] = [s.definition(v, writer) for v, s in zip(value, serializers)] + return out + + @classmethod + def serialize(cls, value, writer=None): + out = cls.definition(value, writer) + out['value'] = [writer.serialize(v, writer) for v in value] + return out + + @classmethod + def deserialize(cls, value, reader=None): + return tuple(reader.deserialize(obj) for obj in value['value']) + + +class UserTypeIO(GraphSONTypeIO): + prefix = 'dse' + graphson_base_type = 'UDT' + cql_type = 'udt' + + FROZEN_REMOVAL_REGEX = re.compile(r'frozen<"*([^"]+)"*>') + + @classmethod + def cql_types_from_string(cls, typ): + # sanitizing: remove frozen references and double quotes... + return cql_types_from_string( + re.sub(cls.FROZEN_REMOVAL_REGEX, r'\1', typ) + ) + + @classmethod + def get_udt_definition(cls, value, writer): + user_type_name = writer.user_types[type(value)] + keyspace = writer.context['graph_name'] + return writer.context['cluster'].metadata.keyspaces[keyspace].user_types[user_type_name] + + @classmethod + def is_collection(cls, typ): + return typ in ['list', 'tuple', 'map', 'set'] + + @classmethod + def is_udt(cls, typ, writer): + keyspace = writer.context['graph_name'] + if keyspace in writer.context['cluster'].metadata.keyspaces: + return typ in writer.context['cluster'].metadata.keyspaces[keyspace].user_types + return False + + @classmethod + def field_definition(cls, types, writer, name=None): + """ + Build the udt field definition. This is required when we have a complex udt type. 
+ """ + index = -1 + out = [OrderedDict() if name is None else OrderedDict([('fieldName', name)])] + + while types: + index += 1 + typ = types.pop(0) + if index > 0: + out.append(OrderedDict()) + + if cls.is_udt(typ, writer): + keyspace = writer.context['graph_name'] + udt = writer.context['cluster'].metadata.keyspaces[keyspace].user_types[typ] + out[index].update(cls.definition(udt, writer)) + elif cls.is_collection(typ): + out[index]['cqlType'] = typ + definition = cls.field_definition(types, writer) + out[index]['definition'] = definition if isinstance(definition, list) else [definition] + else: + out[index]['cqlType'] = typ + + return out if len(out) > 1 else out[0] + + @classmethod + def definition(cls, value, writer=None): + udt = value if isinstance(value, UserType) else cls.get_udt_definition(value, writer) + return OrderedDict([ + ('cqlType', cls.cql_type), + ('keyspace', udt.keyspace), + ('name', udt.name), + ('definition', [ + cls.field_definition(cls.cql_types_from_string(typ), writer, name=name) + for name, typ in zip(udt.field_names, udt.field_types)]) + ]) + + @classmethod + def serialize(cls, value, writer=None): + udt = cls.get_udt_definition(value, writer) + out = cls.definition(value, writer) + out['value'] = [] + for name, typ in zip(udt.field_names, udt.field_types): + out['value'].append(writer.serialize(getattr(value, name), writer)) + return out + + @classmethod + def deserialize(cls, value, reader=None): + udt_class = reader.context['cluster']._user_types[value['keyspace']][value['name']] + kwargs = zip( + list(map(lambda v: v['fieldName'], value['definition'])), + [reader.deserialize(v) for v in value['value']] + ) + return udt_class(**dict(kwargs)) + + +class TTypeIO(GraphSONTypeIO): + prefix = 'g' + graphson_base_type = 'T' + + @classmethod + def deserialize(cls, value, reader=None): + return T.name_to_value[value] + + +class _BaseGraphSONSerializer(object): + + _serializers = OrderedDict() + + @classmethod + def register(cls, type, serializer): + cls._serializers[type] = serializer + + @classmethod + def get_type_definitions(cls): + return cls._serializers.copy() + + @classmethod + def get_serializer(cls, value): + """ + Get the serializer for a python object. + + :param value: The python object. + """ + + # The serializer matching logic is as follow: + # 1. Try to find the python type by direct access. + # 2. Try to find the first serializer by class inheritance. + # 3. If no serializer found, return the raw value. + + # Note that when trying to find the serializer by class inheritance, + # the order that serializers are registered is important. The use of + # an OrderedDict is to avoid the difference between executions. + serializer = None + try: + serializer = cls._serializers[type(value)] + except KeyError: + for key, serializer_ in cls._serializers.items(): + if isinstance(value, key): + serializer = serializer_ + break + + if serializer: + # A serializer can have specialized serializers (e.g for Int32 and Int64, so value dependant) + serializer = serializer.get_specialized_serializer(value) + + return serializer + + @classmethod + def serialize(cls, value, writer=None): + """ + Serialize a python object to GraphSON. + + e.g 'P42DT10H5M37S' + e.g. {'key': value} + + :param value: The python object to serialize. 
+ :param writer: A graphson serializer for recursive types (Optional) + """ + serializer = cls.get_serializer(value) + if serializer: + return serializer.serialize(value, writer or cls) + + return value + + +class GraphSON1Serializer(_BaseGraphSONSerializer): + """ + Serialize python objects to graphson types. + """ + + # When we fall back to a superclass's serializer, we iterate over this map. + # We want that iteration order to be consistent, so we use an OrderedDict, + # not a dict. + _serializers = OrderedDict([ + (str, TextTypeIO), + (bool, BooleanTypeIO), + (bytearray, ByteBufferTypeIO), + (Decimal, BigDecimalTypeIO), + (datetime.date, LocalDateTypeIO), + (datetime.time, LocalTimeTypeIO), + (datetime.timedelta, DurationTypeIO), + (datetime.datetime, InstantTypeIO), + (uuid.UUID, UUIDTypeIO), + (Polygon, PolygonTypeIO), + (Point, PointTypeIO), + (LineString, LineStringTypeIO), + (dict, JsonMapTypeIO), + (float, FloatTypeIO) + ]) + + +GraphSON1Serializer.register(ipaddress.IPv4Address, InetTypeIO) +GraphSON1Serializer.register(ipaddress.IPv6Address, InetTypeIO) +GraphSON1Serializer.register(memoryview, ByteBufferTypeIO) +GraphSON1Serializer.register(bytes, ByteBufferTypeIO) + + +class _BaseGraphSONDeserializer(object): + + _deserializers = {} + + @classmethod + def get_type_definitions(cls): + return cls._deserializers.copy() + + @classmethod + def register(cls, graphson_type, serializer): + cls._deserializers[graphson_type] = serializer + + @classmethod + def get_deserializer(cls, graphson_type): + try: + return cls._deserializers[graphson_type] + except KeyError: + raise ValueError('Invalid `graphson_type` specified: {}'.format(graphson_type)) + + @classmethod + def deserialize(cls, graphson_type, value): + """ + Deserialize a `graphson_type` value to a python object. + + :param graphson_base_type: The graphson graphson_type. e.g. 'gx:Instant' + :param value: The graphson value to deserialize. + """ + return cls.get_deserializer(graphson_type).deserialize(value) + + +class GraphSON1Deserializer(_BaseGraphSONDeserializer): + """ + Deserialize graphson1 types to python objects. 
+ """ + _TYPES = [UUIDTypeIO, BigDecimalTypeIO, InstantTypeIO, BlobTypeIO, ByteBufferTypeIO, + PointTypeIO, LineStringTypeIO, PolygonTypeIO, LocalDateTypeIO, + LocalTimeTypeIO, DurationTypeIO, InetTypeIO] + + _deserializers = { + t.graphson_type: t + for t in _TYPES + } + + @classmethod + def deserialize_date(cls, value): + return cls._deserializers[LocalDateTypeIO.graphson_type].deserialize(value) + + @classmethod + def deserialize_time(cls, value): + return cls._deserializers[LocalTimeTypeIO.graphson_type].deserialize(value) + + @classmethod + def deserialize_timestamp(cls, value): + return cls._deserializers[InstantTypeIO.graphson_type].deserialize(value) + + @classmethod + def deserialize_duration(cls, value): + return cls._deserializers[DurationTypeIO.graphson_type].deserialize(value) + + @classmethod + def deserialize_int(cls, value): + return int(value) + + deserialize_smallint = deserialize_int + + deserialize_varint = deserialize_int + + @classmethod + def deserialize_bigint(cls, value): + return cls.deserialize_int(value) + + @classmethod + def deserialize_double(cls, value): + return float(value) + + deserialize_float = deserialize_double + + @classmethod + def deserialize_uuid(cls, value): + return cls._deserializers[UUIDTypeIO.graphson_type].deserialize(value) + + @classmethod + def deserialize_decimal(cls, value): + return cls._deserializers[BigDecimalTypeIO.graphson_type].deserialize(value) + + @classmethod + def deserialize_blob(cls, value): + return cls._deserializers[ByteBufferTypeIO.graphson_type].deserialize(value) + + @classmethod + def deserialize_point(cls, value): + return cls._deserializers[PointTypeIO.graphson_type].deserialize(value) + + @classmethod + def deserialize_linestring(cls, value): + return cls._deserializers[LineStringTypeIO.graphson_type].deserialize(value) + + @classmethod + def deserialize_polygon(cls, value): + return cls._deserializers[PolygonTypeIO.graphson_type].deserialize(value) + + @classmethod + def deserialize_inet(cls, value): + return value + + @classmethod + def deserialize_boolean(cls, value): + return value + + +# TODO Remove in the next major +GraphSON1TypeDeserializer = GraphSON1Deserializer +GraphSON1TypeSerializer = GraphSON1Serializer + + +class GraphSON2Serializer(_BaseGraphSONSerializer): + TYPE_KEY = "@type" + VALUE_KEY = "@value" + + _serializers = GraphSON1Serializer.get_type_definitions() + + def serialize(self, value, writer=None): + """ + Serialize a type to GraphSON2. + + e.g {'@type': 'gx:Duration', '@value': 'P2DT4H'} + + :param value: The python object to serialize. 
+ """ + serializer = self.get_serializer(value) + if not serializer: + raise ValueError("Unable to find a serializer for value of type: ".format(type(value))) + + val = serializer.serialize(value, writer or self) + if serializer is TypeWrapperTypeIO: + graphson_base_type = value.type_io.graphson_base_type + graphson_type = value.type_io.graphson_type + else: + graphson_base_type = serializer.graphson_base_type + graphson_type = serializer.graphson_type + + if graphson_base_type is None: + out = val + else: + out = {self.TYPE_KEY: graphson_type} + if val is not None: + out[self.VALUE_KEY] = val + + return out + + +GraphSON2Serializer.register(int, IntegerTypeIO) + + +class GraphSON2Deserializer(_BaseGraphSONDeserializer): + + _TYPES = GraphSON1Deserializer._TYPES + [ + Int16TypeIO, Int32TypeIO, Int64TypeIO, DoubleTypeIO, FloatTypeIO, + BigIntegerTypeIO, VertexTypeIO, VertexPropertyTypeIO, EdgeTypeIO, + PathTypeIO, PropertyTypeIO, TraversalMetricsTypeIO, MetricsTypeIO] + + _deserializers = { + t.graphson_type: t + for t in _TYPES + } + + +class GraphSON2Reader(object): + """ + GraphSON2 Reader that parse json and deserialize to python objects. + """ + + def __init__(self, context, extra_deserializer_map=None): + """ + :param extra_deserializer_map: map from GraphSON type tag to deserializer instance implementing `deserialize` + """ + self.context = context + self.deserializers = GraphSON2Deserializer.get_type_definitions() + if extra_deserializer_map: + self.deserializers.update(extra_deserializer_map) + + def read(self, json_data): + """ + Read and deserialize ``json_data``. + """ + return self.deserialize(json.loads(json_data)) + + def deserialize(self, obj): + """ + Deserialize GraphSON type-tagged dict values into objects mapped in self.deserializers + """ + if isinstance(obj, dict): + try: + des = self.deserializers[obj[GraphSON2Serializer.TYPE_KEY]] + return des.deserialize(obj[GraphSON2Serializer.VALUE_KEY], self) + except KeyError: + pass + # list and map are treated as normal json objs (could be isolated deserializers) + return {self.deserialize(k): self.deserialize(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [self.deserialize(o) for o in obj] + else: + return obj + + +class TypeIOWrapper(object): + """Used to force a graphson type during serialization""" + + type_io = None + value = None + + def __init__(self, type_io, value): + self.type_io = type_io + self.value = value + + +def _wrap_value(type_io, value): + return TypeIOWrapper(type_io, value) + + +to_bigint = partial(_wrap_value, Int64TypeIO) +to_int = partial(_wrap_value, Int32TypeIO) +to_smallint = partial(_wrap_value, Int16TypeIO) +to_double = partial(_wrap_value, DoubleTypeIO) +to_float = partial(_wrap_value, FloatTypeIO) + + +class GraphSON3Serializer(GraphSON2Serializer): + + _serializers = GraphSON2Serializer.get_type_definitions() + + context = None + """A dict of the serialization context""" + + def __init__(self, context): + self.context = context + self.user_types = None + + def definition(self, value): + serializer = self.get_serializer(value) + return serializer.definition(value, self) + + def get_serializer(self, value): + """Custom get_serializer to support UDT/Tuple""" + + serializer = super(GraphSON3Serializer, self).get_serializer(value) + is_namedtuple_udt = serializer is TupleTypeIO and hasattr(value, '_fields') + if not serializer or is_namedtuple_udt: + # Check if UDT + if self.user_types is None: + try: + user_types = self.context['cluster']._user_types[self.context['graph_name']] 
+ self.user_types = dict(map(reversed, user_types.items())) + except KeyError: + self.user_types = {} + + serializer = UserTypeIO if (is_namedtuple_udt or (type(value) in self.user_types)) else serializer + + return serializer + + +GraphSON3Serializer.register(dict, MapTypeIO) +GraphSON3Serializer.register(list, ListTypeIO) +GraphSON3Serializer.register(set, SetTypeIO) +GraphSON3Serializer.register(tuple, TupleTypeIO) +GraphSON3Serializer.register(Duration, DseDurationTypeIO) +GraphSON3Serializer.register(TypeIOWrapper, TypeWrapperTypeIO) + + +class GraphSON3Deserializer(GraphSON2Deserializer): + _TYPES = GraphSON2Deserializer._TYPES + [MapTypeIO, ListTypeIO, + SetTypeIO, TupleTypeIO, + UserTypeIO, DseDurationTypeIO, + TTypeIO, BulkSetTypeIO] + + _deserializers = {t.graphson_type: t for t in _TYPES} + + +class GraphSON3Reader(GraphSON2Reader): + """ + GraphSON3 Reader that parse json and deserialize to python objects. + """ + + def __init__(self, context, extra_deserializer_map=None): + """ + :param context: A dict of the context, mostly used as context for udt deserialization. + :param extra_deserializer_map: map from GraphSON type tag to deserializer instance implementing `deserialize` + """ + self.context = context + self.deserializers = GraphSON3Deserializer.get_type_definitions() + if extra_deserializer_map: + self.deserializers.update(extra_deserializer_map) diff --git a/cassandra/datastax/graph/query.py b/cassandra/datastax/graph/query.py new file mode 100644 index 0000000000..d5f2a594b3 --- /dev/null +++ b/cassandra/datastax/graph/query.py @@ -0,0 +1,332 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from warnings import warn + +from cassandra import ConsistencyLevel +from cassandra.query import Statement, SimpleStatement +from cassandra.datastax.graph.types import Vertex, Edge, Path, VertexProperty +from cassandra.datastax.graph.graphson import GraphSON2Reader, GraphSON3Reader + + +__all__ = [ + 'GraphProtocol', 'GraphOptions', 'GraphStatement', 'SimpleGraphStatement', + 'single_object_row_factory', 'graph_result_row_factory', 'graph_object_row_factory', + 'graph_graphson2_row_factory', 'Result', 'graph_graphson3_row_factory' +] + +# (attr, description, server option) +_graph_options = ( + ('graph_name', 'name of the targeted graph.', 'graph-name'), + ('graph_source', 'choose the graph traversal source, configured on the server side.', 'graph-source'), + ('graph_language', 'the language used in the queries (default "gremlin-groovy")', 'graph-language'), + ('graph_protocol', 'the graph protocol that the server should use for query results (default "graphson-1-0")', 'graph-results'), + ('graph_read_consistency_level', '''read `cassandra.ConsistencyLevel `_ for graph queries (if distinct from session default). 
+Setting this overrides the native `Statement.consistency_level `_ for read operations from Cassandra persistence''', 'graph-read-consistency'), + ('graph_write_consistency_level', '''write `cassandra.ConsistencyLevel `_ for graph queries (if distinct from session default). +Setting this overrides the native `Statement.consistency_level `_ for write operations to Cassandra persistence.''', 'graph-write-consistency') +) +_graph_option_names = tuple(option[0] for option in _graph_options) + +# this is defined by the execution profile attribute, not in graph options +_request_timeout_key = 'request-timeout' + + +class GraphProtocol(object): + + GRAPHSON_1_0 = b'graphson-1.0' + """ + GraphSON1 + """ + + GRAPHSON_2_0 = b'graphson-2.0' + """ + GraphSON2 + """ + + GRAPHSON_3_0 = b'graphson-3.0' + """ + GraphSON3 + """ + + +class GraphOptions(object): + """ + Options for DSE Graph Query handler. + """ + # See _graph_options map above for notes on valid options + + DEFAULT_GRAPH_PROTOCOL = GraphProtocol.GRAPHSON_1_0 + DEFAULT_GRAPH_LANGUAGE = b'gremlin-groovy' + + def __init__(self, **kwargs): + self._graph_options = {} + kwargs.setdefault('graph_source', 'g') + kwargs.setdefault('graph_language', GraphOptions.DEFAULT_GRAPH_LANGUAGE) + for attr, value in kwargs.items(): + if attr not in _graph_option_names: + warn("Unknown keyword argument received for GraphOptions: {0}".format(attr)) + setattr(self, attr, value) + + def copy(self): + new_options = GraphOptions() + new_options._graph_options = self._graph_options.copy() + return new_options + + def update(self, options): + self._graph_options.update(options._graph_options) + + def get_options_map(self, other_options=None): + """ + Returns a map for these options updated with other options, + and mapped to graph payload types. 
+ """ + options = self._graph_options.copy() + if other_options: + options.update(other_options._graph_options) + + # cls are special-cased so they can be enums in the API, and names in the protocol + for cl in ('graph-write-consistency', 'graph-read-consistency'): + cl_enum = options.get(cl) + if cl_enum is not None: + options[cl] = ConsistencyLevel.value_to_name[cl_enum].encode() + return options + + def set_source_default(self): + """ + Sets ``graph_source`` to the server-defined default traversal source ('default') + """ + self.graph_source = 'default' + + def set_source_analytics(self): + """ + Sets ``graph_source`` to the server-defined analytic traversal source ('a') + """ + self.graph_source = 'a' + + def set_source_graph(self): + """ + Sets ``graph_source`` to the server-defined graph traversal source ('g') + """ + self.graph_source = 'g' + + def set_graph_protocol(self, protocol): + """ + Sets ``graph_protocol`` as server graph results format (See :class:`cassandra.datastax.graph.GraphProtocol`) + """ + self.graph_protocol = protocol + + @property + def is_default_source(self): + return self.graph_source in (b'default', None) + + @property + def is_analytics_source(self): + """ + True if ``graph_source`` is set to the server-defined analytics traversal source ('a') + """ + return self.graph_source == b'a' + + @property + def is_graph_source(self): + """ + True if ``graph_source`` is set to the server-defined graph traversal source ('g') + """ + return self.graph_source == b'g' + + +for opt in _graph_options: + + def get(self, key=opt[2]): + return self._graph_options.get(key) + + def set(self, value, key=opt[2]): + if value is not None: + # normalize text here so it doesn't have to be done every time we get options map + if isinstance(value, str): + value = value.encode() + self._graph_options[key] = value + else: + self._graph_options.pop(key, None) + + def delete(self, key=opt[2]): + self._graph_options.pop(key, None) + + setattr(GraphOptions, opt[0], property(get, set, delete, opt[1])) + + +class GraphStatement(Statement): + """ An abstract class representing a graph query.""" + + @property + def query(self): + raise NotImplementedError() + + def __str__(self): + return u''.format(self.query) + __repr__ = __str__ + + +class SimpleGraphStatement(GraphStatement, SimpleStatement): + """ + Simple graph statement for :meth:`.Session.execute_graph`. + Takes the same parameters as :class:`.SimpleStatement`. + """ + @property + def query(self): + return self._query_string + + +def single_object_row_factory(column_names, rows): + """ + returns the JSON string value of graph results + """ + return [row[0] for row in rows] + + +def graph_result_row_factory(column_names, rows): + """ + Returns a :class:`Result ` object that can load graph results and produce specific types. + The Result JSON is deserialized and unpacked from the top-level 'result' dict. + """ + return [Result(json.loads(row[0])['result']) for row in rows] + + +def graph_object_row_factory(column_names, rows): + """ + Like :func:`~.graph_result_row_factory`, except known element types (:class:`~.Vertex`, :class:`~.Edge`) are + converted to their simplified objects. Some low-level metadata is shed in this conversion. Unknown result types are + still returned as :class:`Result `. 
+ """ + return _graph_object_sequence(json.loads(row[0])['result'] for row in rows) + + +def _graph_object_sequence(objects): + for o in objects: + res = Result(o) + if isinstance(o, dict): + typ = res.value.get('type') + if typ == 'vertex': + res = res.as_vertex() + elif typ == 'edge': + res = res.as_edge() + yield res + + +class _GraphSONContextRowFactory(object): + graphson_reader_class = None + graphson_reader_kwargs = None + + def __init__(self, cluster): + context = {'cluster': cluster} + kwargs = self.graphson_reader_kwargs or {} + self.graphson_reader = self.graphson_reader_class(context, **kwargs) + + def __call__(self, column_names, rows): + return [self.graphson_reader.read(row[0])['result'] for row in rows] + + +class _GraphSON2RowFactory(_GraphSONContextRowFactory): + """Row factory to deserialize GraphSON2 results.""" + graphson_reader_class = GraphSON2Reader + + +class _GraphSON3RowFactory(_GraphSONContextRowFactory): + """Row factory to deserialize GraphSON3 results.""" + graphson_reader_class = GraphSON3Reader + + +graph_graphson2_row_factory = _GraphSON2RowFactory +graph_graphson3_row_factory = _GraphSON3RowFactory + + +class Result(object): + """ + Represents deserialized graph results. + Property and item getters are provided for convenience. + """ + + value = None + """ + Deserialized value from the result + """ + + def __init__(self, value): + self.value = value + + def __getattr__(self, attr): + if not isinstance(self.value, dict): + raise ValueError("Value cannot be accessed as a dict") + + if attr in self.value: + return self.value[attr] + + raise AttributeError("Result has no top-level attribute %r" % (attr,)) + + def __getitem__(self, item): + if isinstance(self.value, dict) and isinstance(item, str): + return self.value[item] + elif isinstance(self.value, list) and isinstance(item, int): + return self.value[item] + else: + raise ValueError("Result cannot be indexed by %r" % (item,)) + + def __str__(self): + return str(self.value) + + def __repr__(self): + return "%s(%r)" % (Result.__name__, self.value) + + def __eq__(self, other): + return self.value == other.value + + def as_vertex(self): + """ + Return a :class:`Vertex` parsed from this result + + Raises TypeError if parsing fails (i.e. the result structure is not valid). + """ + try: + return Vertex(self.id, self.label, self.type, self.value.get('properties', {})) + except (AttributeError, ValueError, TypeError): + raise TypeError("Could not create Vertex from %r" % (self,)) + + def as_edge(self): + """ + Return a :class:`Edge` parsed from this result + + Raises TypeError if parsing fails (i.e. the result structure is not valid). + """ + try: + return Edge(self.id, self.label, self.type, self.value.get('properties', {}), + self.inV, self.inVLabel, self.outV, self.outVLabel) + except (AttributeError, ValueError, TypeError): + raise TypeError("Could not create Edge from %r" % (self,)) + + def as_path(self): + """ + Return a :class:`Path` parsed from this result + + Raises TypeError if parsing fails (i.e. the result structure is not valid). 
+ """ + try: + return Path(self.labels, self.objects) + except (AttributeError, ValueError, TypeError): + raise TypeError("Could not create Path from %r" % (self,)) + + def as_vertex_property(self): + return VertexProperty(self.value.get('label'), self.value.get('value'), self.value.get('properties', {})) diff --git a/cassandra/datastax/graph/types.py b/cassandra/datastax/graph/types.py new file mode 100644 index 0000000000..75902c6622 --- /dev/null +++ b/cassandra/datastax/graph/types.py @@ -0,0 +1,212 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = ['Element', 'Vertex', 'Edge', 'VertexProperty', 'Path', 'T'] + + +class Element(object): + + element_type = None + + _attrs = ('id', 'label', 'type', 'properties') + + def __init__(self, id, label, type, properties): + if type != self.element_type: + raise TypeError("Attempted to create %s from %s element", (type, self.element_type)) + + self.id = id + self.label = label + self.type = type + self.properties = self._extract_properties(properties) + + @staticmethod + def _extract_properties(properties): + return dict(properties) + + def __eq__(self, other): + return all(getattr(self, attr) == getattr(other, attr) for attr in self._attrs) + + def __str__(self): + return str(dict((k, getattr(self, k)) for k in self._attrs)) + + +class Vertex(Element): + """ + Represents a Vertex element from a graph query. + + Vertex ``properties`` are extracted into a ``dict`` of property names to list of :class:`~VertexProperty` (list + because they are always encoded that way, and sometimes have multiple cardinality; VertexProperty because sometimes + the properties themselves have property maps). + """ + + element_type = 'vertex' + + @staticmethod + def _extract_properties(properties): + # vertex properties are always encoded as a list, regardless of Cardinality + return dict((k, [VertexProperty(k, p['value'], p.get('properties')) for p in v]) for k, v in properties.items()) + + def __repr__(self): + properties = dict((name, [{'label': prop.label, 'value': prop.value, 'properties': prop.properties} for prop in prop_list]) + for name, prop_list in self.properties.items()) + return "%s(%r, %r, %r, %r)" % (self.__class__.__name__, + self.id, self.label, + self.type, properties) + + +class VertexProperty(object): + """ + Vertex properties have a top-level value and an optional ``dict`` of properties. 
+ """ + + label = None + """ + label of the property + """ + + value = None + """ + Value of the property + """ + + properties = None + """ + dict of properties attached to the property + """ + + def __init__(self, label, value, properties=None): + self.label = label + self.value = value + self.properties = properties or {} + + def __eq__(self, other): + return isinstance(other, VertexProperty) and self.label == other.label and self.value == other.value and self.properties == other.properties + + def __repr__(self): + return "%s(%r, %r, %r)" % (self.__class__.__name__, self.label, self.value, self.properties) + + +class Edge(Element): + """ + Represents an Edge element from a graph query. + + Attributes match initializer parameters. + """ + + element_type = 'edge' + + _attrs = Element._attrs + ('inV', 'inVLabel', 'outV', 'outVLabel') + + def __init__(self, id, label, type, properties, + inV, inVLabel, outV, outVLabel): + super(Edge, self).__init__(id, label, type, properties) + self.inV = inV + self.inVLabel = inVLabel + self.outV = outV + self.outVLabel = outVLabel + + def __repr__(self): + return "%s(%r, %r, %r, %r, %r, %r, %r, %r)" %\ + (self.__class__.__name__, + self.id, self.label, + self.type, self.properties, + self.inV, self.inVLabel, + self.outV, self.outVLabel) + + +class Path(object): + """ + Represents a graph path. + + Labels list is taken verbatim from the results. + + Objects are either :class:`~.Result` or :class:`~.Vertex`/:class:`~.Edge` for recognized types + """ + + labels = None + """ + List of labels in the path + """ + + objects = None + """ + List of objects in the path + """ + + def __init__(self, labels, objects): + # TODO fix next major + # The Path class should not do any deserialization by itself. To fix in the next major. + from cassandra.datastax.graph.query import _graph_object_sequence + self.labels = labels + self.objects = list(_graph_object_sequence(objects)) + + def __eq__(self, other): + return self.labels == other.labels and self.objects == other.objects + + def __str__(self): + return str({'labels': self.labels, 'objects': self.objects}) + + def __repr__(self): + return "%s(%r, %r)" % (self.__class__.__name__, self.labels, [o.value for o in self.objects]) + + +class T(object): + """ + Represents a collection of tokens for more concise Traversal definitions. + """ + + name = None + val = None + + # class attributes + id = None + """ + """ + + key = None + """ + """ + label = None + """ + """ + value = None + """ + """ + + def __init__(self, name, val): + self.name = name + self.val = val + + def __str__(self): + return self.name + + def __repr__(self): + return "T.%s" % (self.name, ) + + +T.id = T("id", 1) +T.id_ = T("id_", 2) +T.key = T("key", 3) +T.label = T("label", 4) +T.value = T("value", 5) + +T.name_to_value = { + 'id': T.id, + 'id_': T.id_, + 'key': T.key, + 'label': T.label, + 'value': T.value +} diff --git a/cassandra/datastax/insights/__init__.py b/cassandra/datastax/insights/__init__.py new file mode 100644 index 0000000000..635f0d9e60 --- /dev/null +++ b/cassandra/datastax/insights/__init__.py @@ -0,0 +1,15 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cassandra/datastax/insights/registry.py b/cassandra/datastax/insights/registry.py new file mode 100644 index 0000000000..523af4dc84 --- /dev/null +++ b/cassandra/datastax/insights/registry.py @@ -0,0 +1,124 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict +from warnings import warn + +from cassandra.datastax.insights.util import namespace + +_NOT_SET = object() + + +def _default_serializer_for_object(obj, policy): + # the insights server expects an 'options' dict for policy + # objects, but not for other objects + if policy: + return {'type': obj.__class__.__name__, + 'namespace': namespace(obj.__class__), + 'options': {}} + else: + return {'type': obj.__class__.__name__, + 'namespace': namespace(obj.__class__)} + + +class InsightsSerializerRegistry(object): + + initialized = False + + def __init__(self, mapping_dict=None): + mapping_dict = mapping_dict or {} + class_order = self._class_topological_sort(mapping_dict) + self._mapping_dict = OrderedDict( + ((cls, mapping_dict[cls]) for cls in class_order) + ) + + def serialize(self, obj, policy=False, default=_NOT_SET, cls=None): + try: + return self._get_serializer(cls if cls is not None else obj.__class__)(obj) + except Exception: + if default is _NOT_SET: + result = _default_serializer_for_object(obj, policy) + else: + result = default + + return result + + def _get_serializer(self, cls): + try: + return self._mapping_dict[cls] + except KeyError: + for registered_cls, serializer in self._mapping_dict.items(): + if issubclass(cls, registered_cls): + return self._mapping_dict[registered_cls] + raise ValueError + + def register(self, cls, serializer): + self._mapping_dict[cls] = serializer + self._mapping_dict = OrderedDict( + ((cls, self._mapping_dict[cls]) + for cls in self._class_topological_sort(self._mapping_dict)) + ) + + def register_serializer_for(self, cls): + """ + Parameterized registration helper decorator. Given a class `cls`, + produces a function that registers the decorated function as a + serializer for it. + """ + def decorator(serializer): + self.register(cls, serializer) + return serializer + + return decorator + + @staticmethod + def _class_topological_sort(classes): + """ + A simple topological sort for classes. 
Takes an iterable of class objects + and returns a list A of those classes, ordered such that A[X] is never a + superclass of A[Y] for X < Y. + + This is an inefficient sort, but that's ok because classes are infrequently + registered. It's more important that this be maintainable than fast. + + We can't use `.sort()` or `sorted()` with a custom `key` -- those assume + a total ordering, which we don't have. + """ + unsorted, sorted_ = list(classes), [] + while unsorted: + head, tail = unsorted[0], unsorted[1:] + + # if head has no subclasses remaining, it can safely go in the list + if not any(issubclass(x, head) for x in tail): + sorted_.append(head) + else: + # move to the back -- head has to wait until all its subclasses + # are sorted into the list + tail.append(head) + + unsorted = tail + + # check that sort is valid + for i, head in enumerate(sorted_): + for after_head_value in sorted_[(i + 1):]: + if issubclass(after_head_value, head): + warn('Sorting classes produced an invalid ordering.\n' + 'In: {classes}\n' + 'Out: {sorted_}'.format(classes=classes, sorted_=sorted_)) + return sorted_ + + +insights_registry = InsightsSerializerRegistry() diff --git a/cassandra/datastax/insights/reporter.py b/cassandra/datastax/insights/reporter.py new file mode 100644 index 0000000000..607c723a1a --- /dev/null +++ b/cassandra/datastax/insights/reporter.py @@ -0,0 +1,223 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
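To make the ordering guarantee of _class_topological_sort above concrete: because subclasses are kept ahead of their superclasses in the registry's OrderedDict, an object whose class is not registered directly falls through to the most specific registered ancestor. A small usage sketch, illustrative only and not part of this patch (the policy class names here are hypothetical):

    from cassandra.datastax.insights.registry import InsightsSerializerRegistry

    class BasePolicy(object):
        pass

    class SpecialPolicy(BasePolicy):
        pass

    class VerySpecialPolicy(SpecialPolicy):
        # not registered directly; matched via subclass lookup
        pass

    registry = InsightsSerializerRegistry()
    registry.register(BasePolicy, lambda p: {'type': 'base'})
    registry.register(SpecialPolicy, lambda p: {'type': 'special'})

    # SpecialPolicy is checked before BasePolicy, so the more specific serializer wins.
    print(registry.serialize(VerySpecialPolicy()))   # {'type': 'special'}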
+ +from collections import Counter +import datetime +import json +import logging +import multiprocessing +import random +import platform +import socket +import ssl +import sys +from threading import Event, Thread +import time + +from cassandra.policies import HostDistance +from cassandra.util import ms_timestamp_from_datetime +from cassandra.datastax.insights.registry import insights_registry +from cassandra.datastax.insights.serializers import initialize_registry + +log = logging.getLogger(__name__) + + +class MonitorReporter(Thread): + + def __init__(self, interval_sec, session): + """ + takes an int indicating interval between requests, a function returning + the connection to be used, and the timeout per request + """ + # Thread is an old-style class so we can't super() + Thread.__init__(self, name='monitor_reporter') + + initialize_registry(insights_registry) + + self._interval, self._session = interval_sec, session + + self._shutdown_event = Event() + self.daemon = True + self.start() + + def run(self): + self._send_via_rpc(self._get_startup_data()) + + # introduce some jitter -- send up to 1/10 of _interval early + self._shutdown_event.wait(self._interval * random.uniform(.9, 1)) + + while not self._shutdown_event.is_set(): + start_time = time.time() + + self._send_via_rpc(self._get_status_data()) + + elapsed = time.time() - start_time + self._shutdown_event.wait(max(self._interval - elapsed, 0.01)) + + # TODO: redundant with ConnectionHeartbeat.ShutdownException + class ShutDownException(Exception): + pass + + def _send_via_rpc(self, data): + try: + self._session.execute( + "CALL InsightsRpc.reportInsight(%s)", (json.dumps(data),) + ) + log.debug('Insights RPC data: {}'.format(data)) + except Exception as e: + log.debug('Insights RPC send failed with {}'.format(e)) + log.debug('Insights RPC data: {}'.format(data)) + + def _get_status_data(self): + cc = self._session.cluster.control_connection + + connected_nodes = { + host.address: { + 'connections': state['open_count'], + 'inFlightQueries': state['in_flights'] + } + for (host, state) in self._session.get_pool_state().items() + } + + return { + 'metadata': { + # shared across drivers; never change + 'name': 'driver.status', + # format version + 'insightMappingId': 'v1', + 'insightType': 'EVENT', + # since epoch + 'timestamp': ms_timestamp_from_datetime(datetime.datetime.utcnow()), + 'tags': { + 'language': 'python' + } + }, + # // 'clientId', 'sessionId' and 'controlConnection' are mandatory + # // the rest of the properties are optional + 'data': { + # // 'clientId' must be the same as the one provided in the startup message + 'clientId': str(self._session.cluster.client_id), + # // 'sessionId' must be the same as the one provided in the startup message + 'sessionId': str(self._session.session_id), + 'controlConnection': cc._connection.host if cc._connection else None, + 'connectedNodes': connected_nodes + } + } + + def _get_startup_data(self): + cc = self._session.cluster.control_connection + try: + local_ipaddr = cc._connection._socket.getsockname()[0] + except Exception as e: + local_ipaddr = None + log.debug('Unable to get local socket addr from {}: {}'.format(cc._connection, e)) + hostname = socket.getfqdn() + + host_distances_counter = Counter( + self._session.cluster.profile_manager.distance(host) + for host in self._session.hosts + ) + host_distances_dict = { + 'local': host_distances_counter[HostDistance.LOCAL], + 'remote': host_distances_counter[HostDistance.REMOTE], + 'ignored': 
host_distances_counter[HostDistance.IGNORED] + } + + try: + compression_type = cc._connection._compression_type + except AttributeError: + compression_type = 'NONE' + + cert_validation = None + try: + if self._session.cluster.ssl_context: + if isinstance(self._session.cluster.ssl_context, ssl.SSLContext): + cert_validation = self._session.cluster.ssl_context.verify_mode == ssl.CERT_REQUIRED + else: # pyopenssl + from OpenSSL import SSL + cert_validation = self._session.cluster.ssl_context.get_verify_mode() != SSL.VERIFY_NONE + elif self._session.cluster.ssl_options: + cert_validation = self._session.cluster.ssl_options.get('cert_reqs') == ssl.CERT_REQUIRED + except Exception as e: + log.debug('Unable to get the cert validation: {}'.format(e)) + + uname_info = platform.uname() + + return { + 'metadata': { + 'name': 'driver.startup', + 'insightMappingId': 'v1', + 'insightType': 'EVENT', + 'timestamp': ms_timestamp_from_datetime(datetime.datetime.utcnow()), + 'tags': { + 'language': 'python' + }, + }, + 'data': { + 'driverName': 'DataStax Python Driver', + 'driverVersion': sys.modules['cassandra'].__version__, + 'clientId': str(self._session.cluster.client_id), + 'sessionId': str(self._session.session_id), + 'applicationName': self._session.cluster.application_name or 'python', + 'applicationNameWasGenerated': not self._session.cluster.application_name, + 'applicationVersion': self._session.cluster.application_version, + 'contactPoints': self._session.cluster._endpoint_map_for_insights, + 'dataCenters': list(set(h.datacenter for h in self._session.cluster.metadata.all_hosts() + if (h.datacenter and + self._session.cluster.profile_manager.distance(h) == HostDistance.LOCAL))), + 'initialControlConnection': cc._connection.host if cc._connection else None, + 'protocolVersion': self._session.cluster.protocol_version, + 'localAddress': local_ipaddr, + 'hostName': hostname, + 'executionProfiles': insights_registry.serialize(self._session.cluster.profile_manager), + 'configuredConnectionLength': host_distances_dict, + 'heartbeatInterval': self._session.cluster.idle_heartbeat_interval, + 'compression': compression_type.upper() if compression_type else 'NONE', + 'reconnectionPolicy': insights_registry.serialize(self._session.cluster.reconnection_policy), + 'sslConfigured': { + 'enabled': bool(self._session.cluster.ssl_options or self._session.cluster.ssl_context), + 'certValidation': cert_validation + }, + 'authProvider': { + 'type': (self._session.cluster.auth_provider.__class__.__name__ + if self._session.cluster.auth_provider else + None) + }, + 'otherOptions': { + }, + 'platformInfo': { + 'os': { + 'name': uname_info.system, + 'version': uname_info.release, + 'arch': uname_info.machine + }, + 'cpus': { + 'length': multiprocessing.cpu_count(), + 'model': platform.processor() + }, + 'runtime': { + 'python': sys.version, + 'event_loop': self._session.cluster.connection_class.__name__ + } + }, + 'periodicStatusInterval': self._interval + } + } + + def stop(self): + log.debug("Shutting down Monitor Reporter") + self._shutdown_event.set() + self.join() diff --git a/cassandra/datastax/insights/serializers.py b/cassandra/datastax/insights/serializers.py new file mode 100644 index 0000000000..b1fe0ac5e9 --- /dev/null +++ b/cassandra/datastax/insights/serializers.py @@ -0,0 +1,221 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def initialize_registry(insights_registry): + # This will be called from the cluster module, so we put all this behavior + # in a function to avoid circular imports + + if insights_registry.initialized: + return False + + from cassandra import ConsistencyLevel + from cassandra.cluster import ( + ExecutionProfile, GraphExecutionProfile, + ProfileManager, ContinuousPagingOptions, + EXEC_PROFILE_DEFAULT, EXEC_PROFILE_GRAPH_DEFAULT, + EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT, + EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT, + _NOT_SET + ) + from cassandra.datastax.graph import GraphOptions + from cassandra.datastax.insights.registry import insights_registry + from cassandra.datastax.insights.util import namespace + from cassandra.policies import ( + RoundRobinPolicy, + DCAwareRoundRobinPolicy, + TokenAwarePolicy, + WhiteListRoundRobinPolicy, + HostFilterPolicy, + ConstantReconnectionPolicy, + ExponentialReconnectionPolicy, + RetryPolicy, + SpeculativeExecutionPolicy, + ConstantSpeculativeExecutionPolicy, + WrapperPolicy + ) + + import logging + + log = logging.getLogger(__name__) + + @insights_registry.register_serializer_for(RoundRobinPolicy) + def round_robin_policy_insights_serializer(policy): + return {'type': policy.__class__.__name__, + 'namespace': namespace(policy.__class__), + 'options': {}} + + @insights_registry.register_serializer_for(DCAwareRoundRobinPolicy) + def dc_aware_round_robin_policy_insights_serializer(policy): + return {'type': policy.__class__.__name__, + 'namespace': namespace(policy.__class__), + 'options': {'local_dc': policy.local_dc, + 'used_hosts_per_remote_dc': policy.used_hosts_per_remote_dc} + } + + @insights_registry.register_serializer_for(TokenAwarePolicy) + def token_aware_policy_insights_serializer(policy): + return {'type': policy.__class__.__name__, + 'namespace': namespace(policy.__class__), + 'options': {'child_policy': insights_registry.serialize(policy._child_policy, + policy=True), + 'shuffle_replicas': policy.shuffle_replicas} + } + + @insights_registry.register_serializer_for(WhiteListRoundRobinPolicy) + def whitelist_round_robin_policy_insights_serializer(policy): + return {'type': policy.__class__.__name__, + 'namespace': namespace(policy.__class__), + 'options': {'allowed_hosts': policy._allowed_hosts} + } + + @insights_registry.register_serializer_for(HostFilterPolicy) + def host_filter_policy_insights_serializer(policy): + return { + 'type': policy.__class__.__name__, + 'namespace': namespace(policy.__class__), + 'options': {'child_policy': insights_registry.serialize(policy._child_policy, + policy=True), + 'predicate': policy.predicate.__name__} + } + + @insights_registry.register_serializer_for(ConstantReconnectionPolicy) + def constant_reconnection_policy_insights_serializer(policy): + return {'type': policy.__class__.__name__, + 'namespace': namespace(policy.__class__), + 'options': {'delay': policy.delay, + 'max_attempts': policy.max_attempts} + } + + 
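Putting the policy serializers above together: a wrapped load-balancing policy serializes recursively through its child_policy option. A rough sketch of the resulting payload, assuming the driver is importable and a local_dc of 'dc1' (illustrative only, not part of this patch; exact values depend on configuration):

    from cassandra.policies import TokenAwarePolicy, DCAwareRoundRobinPolicy
    from cassandra.datastax.insights.registry import insights_registry
    from cassandra.datastax.insights.serializers import initialize_registry

    initialize_registry(insights_registry)
    policy = TokenAwarePolicy(DCAwareRoundRobinPolicy(local_dc='dc1'))
    print(insights_registry.serialize(policy, policy=True))
    # Approximately:
    # {'type': 'TokenAwarePolicy', 'namespace': 'cassandra.policies',
    #  'options': {'child_policy': {'type': 'DCAwareRoundRobinPolicy',
    #                               'namespace': 'cassandra.policies',
    #                               'options': {'local_dc': 'dc1',
    #                                           'used_hosts_per_remote_dc': 0}},
    #              'shuffle_replicas': False}}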
@insights_registry.register_serializer_for(ExponentialReconnectionPolicy) + def exponential_reconnection_policy_insights_serializer(policy): + return {'type': policy.__class__.__name__, + 'namespace': namespace(policy.__class__), + 'options': {'base_delay': policy.base_delay, + 'max_delay': policy.max_delay, + 'max_attempts': policy.max_attempts} + } + + @insights_registry.register_serializer_for(RetryPolicy) + def retry_policy_insights_serializer(policy): + return {'type': policy.__class__.__name__, + 'namespace': namespace(policy.__class__), + 'options': {}} + + @insights_registry.register_serializer_for(SpeculativeExecutionPolicy) + def speculative_execution_policy_insights_serializer(policy): + return {'type': policy.__class__.__name__, + 'namespace': namespace(policy.__class__), + 'options': {}} + + @insights_registry.register_serializer_for(ConstantSpeculativeExecutionPolicy) + def constant_speculative_execution_policy_insights_serializer(policy): + return {'type': policy.__class__.__name__, + 'namespace': namespace(policy.__class__), + 'options': {'delay': policy.delay, + 'max_attempts': policy.max_attempts} + } + + @insights_registry.register_serializer_for(WrapperPolicy) + def wrapper_policy_insights_serializer(policy): + return {'type': policy.__class__.__name__, + 'namespace': namespace(policy.__class__), + 'options': { + 'child_policy': insights_registry.serialize(policy._child_policy, + policy=True) + }} + + @insights_registry.register_serializer_for(ExecutionProfile) + def execution_profile_insights_serializer(profile): + return { + 'loadBalancing': insights_registry.serialize(profile.load_balancing_policy, + policy=True), + 'retry': insights_registry.serialize(profile.retry_policy, + policy=True), + 'readTimeout': profile.request_timeout, + 'consistency': ConsistencyLevel.value_to_name.get(profile.consistency_level, None), + 'serialConsistency': ConsistencyLevel.value_to_name.get(profile.serial_consistency_level, None), + 'continuousPagingOptions': (insights_registry.serialize(profile.continuous_paging_options) + if (profile.continuous_paging_options is not None and + profile.continuous_paging_options is not _NOT_SET) else + None), + 'speculativeExecution': insights_registry.serialize(profile.speculative_execution_policy), + 'graphOptions': None + } + + @insights_registry.register_serializer_for(GraphExecutionProfile) + def graph_execution_profile_insights_serializer(profile): + rv = insights_registry.serialize(profile, cls=ExecutionProfile) + rv['graphOptions'] = insights_registry.serialize(profile.graph_options) + return rv + + _EXEC_PROFILE_DEFAULT_KEYS = (EXEC_PROFILE_DEFAULT, + EXEC_PROFILE_GRAPH_DEFAULT, + EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT, + EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT) + + @insights_registry.register_serializer_for(ProfileManager) + def profile_manager_insights_serializer(manager): + defaults = { + # Insights's expected default + 'default': insights_registry.serialize(manager.profiles[EXEC_PROFILE_DEFAULT]), + # remaining named defaults for driver's defaults, including duplicated default + 'EXEC_PROFILE_DEFAULT': insights_registry.serialize(manager.profiles[EXEC_PROFILE_DEFAULT]), + 'EXEC_PROFILE_GRAPH_DEFAULT': insights_registry.serialize(manager.profiles[EXEC_PROFILE_GRAPH_DEFAULT]), + 'EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT': insights_registry.serialize( + manager.profiles[EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT] + ), + 'EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT': insights_registry.serialize( + manager.profiles[EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT] + ) + } + other = { + 
key: insights_registry.serialize(value) + for key, value in manager.profiles.items() + if key not in _EXEC_PROFILE_DEFAULT_KEYS + } + overlapping_keys = set(defaults) & set(other) + if overlapping_keys: + log.debug('The following key names overlap default key sentinel keys ' + 'and these non-default EPs will not be displayed in Insights ' + ': {}'.format(list(overlapping_keys))) + + other.update(defaults) + return other + + @insights_registry.register_serializer_for(GraphOptions) + def graph_options_insights_serializer(options): + rv = { + 'source': options.graph_source, + 'language': options.graph_language, + 'graphProtocol': options.graph_protocol + } + updates = {k: v.decode('utf-8') for k, v in rv.items() + if isinstance(v, bytes)} + rv.update(updates) + return rv + + @insights_registry.register_serializer_for(ContinuousPagingOptions) + def continuous_paging_options_insights_serializer(paging_options): + return { + 'page_unit': paging_options.page_unit, + 'max_pages': paging_options.max_pages, + 'max_pages_per_second': paging_options.max_pages_per_second, + 'max_queue_size': paging_options.max_queue_size + } + + insights_registry.initialized = True + return True diff --git a/cassandra/datastax/insights/util.py b/cassandra/datastax/insights/util.py new file mode 100644 index 0000000000..0ce96c7edf --- /dev/null +++ b/cassandra/datastax/insights/util.py @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import traceback +from warnings import warn + +from cassandra.util import Version + + +DSE_60 = Version('6.0.0') +DSE_51_MIN_SUPPORTED = Version('5.1.13') +DSE_60_MIN_SUPPORTED = Version('6.0.5') + + +log = logging.getLogger(__name__) + + +def namespace(cls): + """ + Best-effort method for getting the namespace in which a class is defined. + """ + try: + # __module__ can be None + module = cls.__module__ or '' + except Exception: + warn("Unable to obtain namespace for {cls} for Insights, returning ''. " + "Exception: \n{e}".format(e=traceback.format_exc(), cls=cls)) + module = '' + + module_internal_namespace = _module_internal_namespace_or_emtpy_string(cls) + if module_internal_namespace: + return '.'.join((module, module_internal_namespace)) + return module + + +def _module_internal_namespace_or_emtpy_string(cls): + """ + Best-effort method for getting the module-internal namespace in which a + class is defined -- i.e. the namespace _inside_ the module. 
+ """ + try: + qualname = cls.__qualname__ + except AttributeError: + return '' + + return '.'.join( + # the last segment is the name of the class -- use everything else + qualname.split('.')[:-1] + ) + + +def version_supports_insights(dse_version): + if dse_version: + try: + dse_version = Version(dse_version) + return (DSE_51_MIN_SUPPORTED <= dse_version < DSE_60 + or + DSE_60_MIN_SUPPORTED <= dse_version) + except Exception: + warn("Unable to check version {v} for Insights compatibility, returning False. " + "Exception: \n{e}".format(e=traceback.format_exc(), v=dse_version)) + + return False diff --git a/cassandra/deserializers.pxd b/cassandra/deserializers.pxd index 0846417054..c8408a57b6 100644 --- a/cassandra/deserializers.pxd +++ b/cassandra/deserializers.pxd @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/cassandra/deserializers.pyx b/cassandra/deserializers.pyx index 3967ea1431..c07d67be91 100644 --- a/cassandra/deserializers.pyx +++ b/cassandra/deserializers.pyx @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -29,8 +31,6 @@ from uuid import UUID from cassandra import cqltypes from cassandra import util -cdef bint PY2 = six.PY2 - cdef class Deserializer: """Cython-based deserializer class for a cqltype""" @@ -44,6 +44,8 @@ cdef class Deserializer: cdef class DesBytesType(Deserializer): cdef deserialize(self, Buffer *buf, int protocol_version): + if buf.size == 0: + return b"" return to_bytes(buf) # this is to facilitate cqlsh integration, which requires bytearrays for BytesType @@ -51,6 +53,8 @@ cdef class DesBytesType(Deserializer): # deserializers.DesBytesType = deserializers.DesBytesTypeByteArray cdef class DesBytesTypeByteArray(Deserializer): cdef deserialize(self, Buffer *buf, int protocol_version): + if buf.size == 0: + return bytearray() return bytearray(buf.ptr[:buf.size]) # TODO: Use libmpdec: http://www.bytereef.org/mpdecimal/index.html @@ -59,7 +63,7 @@ cdef class DesDecimalType(Deserializer): cdef Buffer varint_buf slice_buffer(buf, &varint_buf, 4, buf.size - 4) - scale = unpack_num[int32_t](buf) + cdef int32_t scale = unpack_num[int32_t](buf) unscaled = varint_unpack(&varint_buf) return Decimal('%de%d' % (unscaled, -scale)) @@ -84,8 +88,8 @@ cdef class DesByteType(Deserializer): cdef class DesAsciiType(Deserializer): cdef deserialize(self, Buffer *buf, int protocol_version): - if PY2: - return to_bytes(buf) + if buf.size == 0: + return "" return to_bytes(buf).decode('ascii') @@ -169,6 +173,8 @@ cdef class DesTimeType(Deserializer): cdef class DesUTF8Type(Deserializer): cdef deserialize(self, Buffer *buf, int protocol_version): + if buf.size == 0: + return "" cdef val = to_bytes(buf) return val.decode('utf8') diff --git a/cassandra/encoder.py b/cassandra/encoder.py index 6d8b6ce8a2..94093e85b6 100644 --- a/cassandra/encoder.py +++ b/cassandra/encoder.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
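# Editorial sketch, not part of the patch: a plain-Python rendering of the
# zero-length-buffer handling added to the Cython deserializers above. An
# empty (but non-NULL) value should come back as the type's natural empty
# value instead of being decoded; `buf` is a bytes stand-in for the Buffer
# struct used by the real code.
def deserialize_utf8(buf):
    if len(buf) == 0:
        return ""
    return buf.decode('utf8')

def deserialize_ascii(buf):
    if len(buf) == 0:
        return ""
    return buf.decode('ascii')

def deserialize_blob(buf):
    return b"" if len(buf) == 0 else bytes(buf)

assert deserialize_utf8(b"") == ""
assert deserialize_blob(b"\x00\x01") == b"\x00\x01"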
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -21,32 +23,22 @@ log = logging.getLogger(__name__) from binascii import hexlify +from decimal import Decimal import calendar import datetime import math import sys import types from uuid import UUID -import six +import ipaddress from cassandra.util import (OrderedDict, OrderedMap, OrderedMapSerializedKey, - sortedset, Time, Date) - -if six.PY3: - long = int + sortedset, Time, Date, Point, LineString, Polygon) def cql_quote(term): - # The ordering of this method is important for the result of this method to - # be a native str type (for both Python 2 and 3) - - # Handle quoting of native str and bool types - if isinstance(term, (str, bool)): + if isinstance(term, str): return "'%s'" % str(term).replace("'", "''") - # This branch of the if statement will only be used by Python 2 to catch - # unicode strings, text_type is used to prevent type errors with Python 3. - elif isinstance(term, six.text_type): - return "'%s'" % term.encode('utf8').replace("'", "''") else: return str(term) @@ -70,6 +62,7 @@ class Encoder(object): def __init__(self): self.mapping = { float: self.cql_encode_float, + Decimal: self.cql_encode_decimal, bytearray: self.cql_encode_bytes, str: self.cql_encode_str, int: self.cql_encode_object, @@ -89,22 +82,19 @@ def __init__(self): sortedset: self.cql_encode_set_collection, frozenset: self.cql_encode_set_collection, types.GeneratorType: self.cql_encode_list_collection, - ValueSequence: self.cql_encode_sequence + ValueSequence: self.cql_encode_sequence, + Point: self.cql_encode_str_quoted, + LineString: self.cql_encode_str_quoted, + Polygon: self.cql_encode_str_quoted } - if six.PY2: - self.mapping.update({ - unicode: self.cql_encode_unicode, - buffer: self.cql_encode_bytes, - long: self.cql_encode_object, - types.NoneType: self.cql_encode_none, - }) - else: - self.mapping.update({ - memoryview: self.cql_encode_bytes, - bytes: self.cql_encode_bytes, - type(None): self.cql_encode_none, - }) + self.mapping.update({ + memoryview: self.cql_encode_bytes, + bytes: self.cql_encode_bytes, + type(None): self.cql_encode_none, + ipaddress.IPv4Address: self.cql_encode_ipaddress, + ipaddress.IPv6Address: self.cql_encode_ipaddress + }) def cql_encode_none(self, val): """ @@ -124,16 +114,11 @@ def cql_encode_str(self, val): """ return cql_quote(val) - if six.PY3: - def cql_encode_bytes(self, val): - return (b'0x' + hexlify(val)).decode('utf-8') - elif sys.version_info >= (2, 7): - def cql_encode_bytes(self, val): # noqa - return b'0x' + hexlify(val) - else: - # python 2.6 requires string or read-only buffer for hexlify - def cql_encode_bytes(self, val): # noqa - return b'0x' + hexlify(buffer(val)) + def cql_encode_str_quoted(self, val): + return "'%s'" % val + + def cql_encode_bytes(self, val): + return (b'0x' + hexlify(val)).decode('utf-8') def cql_encode_object(self, val): """ @@ -159,7 +144,7 @@ def cql_encode_datetime(self, val): with millisecond precision. 
""" timestamp = calendar.timegm(val.utctimetuple()) - return str(long(timestamp * 1e3 + getattr(val, 'microsecond', 0) / 1e3)) + return str(int(timestamp * 1e3 + getattr(val, 'microsecond', 0) / 1e3)) def cql_encode_date(self, val): """ @@ -204,7 +189,7 @@ def cql_encode_map_collection(self, val): return '{%s}' % ', '.join('%s: %s' % ( self.mapping.get(type(k), self.cql_encode_object)(k), self.mapping.get(type(v), self.cql_encode_object)(v) - ) for k, v in six.iteritems(val)) + ) for k, v in val.items()) def cql_encode_list_collection(self, val): """ @@ -220,9 +205,22 @@ def cql_encode_set_collection(self, val): """ return '{%s}' % ', '.join(self.mapping.get(type(v), self.cql_encode_object)(v) for v in val) - def cql_encode_all_types(self, val): + def cql_encode_all_types(self, val, as_text_type=False): """ Converts any type into a CQL string, defaulting to ``cql_encode_object`` if :attr:`~Encoder.mapping` does not contain an entry for the type. """ - return self.mapping.get(type(val), self.cql_encode_object)(val) + encoded = self.mapping.get(type(val), self.cql_encode_object)(val) + if as_text_type and not isinstance(encoded, str): + return encoded.decode('utf-8') + return encoded + + def cql_encode_ipaddress(self, val): + """ + Converts an ipaddress (IPV4Address, IPV6Address) to a CQL string. This + is suitable for ``inet`` type columns. + """ + return "'%s'" % val.compressed + + def cql_encode_decimal(self, val): + return self.cql_encode_float(float(val)) \ No newline at end of file diff --git a/cassandra/graph/__init__.py b/cassandra/graph/__init__.py new file mode 100644 index 0000000000..1d33345aad --- /dev/null +++ b/cassandra/graph/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This is only for backward compatibility when migrating from dse-driver. +from cassandra.datastax.graph import * \ No newline at end of file diff --git a/cassandra/graph/graphson.py b/cassandra/graph/graphson.py new file mode 100644 index 0000000000..576d5063fe --- /dev/null +++ b/cassandra/graph/graphson.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# This is only for backward compatibility when migrating from dse-driver. +from cassandra.datastax.graph.graphson import * diff --git a/cassandra/graph/query.py b/cassandra/graph/query.py new file mode 100644 index 0000000000..9003fe280f --- /dev/null +++ b/cassandra/graph/query.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This is only for backward compatibility when migrating from dse-driver. +from cassandra.datastax.graph.query import * diff --git a/cassandra/graph/types.py b/cassandra/graph/types.py new file mode 100644 index 0000000000..53febe7e9c --- /dev/null +++ b/cassandra/graph/types.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This is only for backward compatibility when migrating from dse-driver. +from cassandra.datastax.graph.types import * diff --git a/cassandra/io/__init__.py b/cassandra/io/__init__.py index 87fc3685e0..588a655d98 100644 --- a/cassandra/io/__init__.py +++ b/cassandra/io/__init__.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
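# Editorial sketch, not part of the patch: the cassandra.graph modules added
# above are thin re-export shims kept for dse-driver migrations, so the old
# and new import paths should resolve to the same objects. GraphOptions is
# just one example of a re-exported name.
from cassandra.graph import GraphOptions as compat_GraphOptions
from cassandra.datastax.graph import GraphOptions

assert compat_GraphOptions is GraphOptions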
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/cassandra/io/asyncioreactor.py b/cassandra/io/asyncioreactor.py new file mode 100644 index 0000000000..95f92e26e0 --- /dev/null +++ b/cassandra/io/asyncioreactor.py @@ -0,0 +1,221 @@ +from cassandra.connection import Connection, ConnectionShutdown + +import asyncio +import logging +import os +import socket +import ssl +from threading import Lock, Thread, get_ident + + +log = logging.getLogger(__name__) + + +# This module uses ``yield from`` and ``@asyncio.coroutine`` over ``await`` and +# ``async def`` for pre-Python-3.5 compatibility, so keep in mind that the +# managed coroutines are generator-based, not native coroutines. See PEP 492: +# https://www.python.org/dev/peps/pep-0492/#coroutine-objects + + +try: + asyncio.run_coroutine_threadsafe +except AttributeError: + raise ImportError( + 'Cannot use asyncioreactor without access to ' + 'asyncio.run_coroutine_threadsafe (added in 3.4.6 and 3.5.1)' + ) + + +class AsyncioTimer(object): + """ + An ``asyncioreactor``-specific Timer. Similar to :class:`.connection.Timer, + but with a slightly different API due to limitations in the underlying + ``call_later`` interface. Not meant to be used with a + :class:`.connection.TimerManager`. + """ + + @property + def end(self): + raise NotImplementedError('{} is not compatible with TimerManager and ' + 'does not implement .end()') + + def __init__(self, timeout, callback, loop): + delayed = self._call_delayed_coro(timeout=timeout, + callback=callback) + self._handle = asyncio.run_coroutine_threadsafe(delayed, loop=loop) + + @staticmethod + async def _call_delayed_coro(timeout, callback): + await asyncio.sleep(timeout) + return callback() + + def __lt__(self, other): + try: + return self._handle < other._handle + except AttributeError: + raise NotImplemented + + def cancel(self): + self._handle.cancel() + + def finish(self): + # connection.Timer method not implemented here because we can't inspect + # the Handle returned from call_later + raise NotImplementedError('{} is not compatible with TimerManager and ' + 'does not implement .finish()') + + +class AsyncioConnection(Connection): + """ + An experimental implementation of :class:`.Connection` that uses the + ``asyncio`` module in the Python standard library for its event loop. + + Note that it requires ``asyncio`` features that were only introduced in the + 3.4 line in 3.4.6, and in the 3.5 line in 3.5.1. 
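# Editorial sketch, not part of the patch: the standalone pattern that
# AsyncioTimer and AsyncioConnection above rely on -- one daemon thread runs
# an asyncio loop forever, and other threads hand it work through
# asyncio.run_coroutine_threadsafe(). Names here are illustrative only.
import asyncio
from threading import Thread

loop = asyncio.new_event_loop()
Thread(target=loop.run_forever, daemon=True, name="asyncio_thread").start()

async def _delayed(timeout, callback):
    await asyncio.sleep(timeout)
    return callback()

# Safe to call from any thread; returns a concurrent.futures.Future whose
# cancel() is what AsyncioTimer.cancel() ultimately calls.
future = asyncio.run_coroutine_threadsafe(_delayed(0.1, lambda: "fired"), loop)
print(future.result(timeout=1))   # prints "fired" after ~0.1s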
+ """ + + _loop = None + _pid = os.getpid() + + _lock = Lock() + _loop_thread = None + + _write_queue = None + _write_queue_lock = None + + def __init__(self, *args, **kwargs): + Connection.__init__(self, *args, **kwargs) + + self._connect_socket() + self._socket.setblocking(0) + + self._write_queue = asyncio.Queue() + self._write_queue_lock = asyncio.Lock() + + # see initialize_reactor -- loop is running in a separate thread, so we + # have to use a threadsafe call + self._read_watcher = asyncio.run_coroutine_threadsafe( + self.handle_read(), loop=self._loop + ) + self._write_watcher = asyncio.run_coroutine_threadsafe( + self.handle_write(), loop=self._loop + ) + self._send_options_message() + + @classmethod + def initialize_reactor(cls): + with cls._lock: + if cls._pid != os.getpid(): + cls._loop = None + if cls._loop is None: + cls._loop = asyncio.new_event_loop() + asyncio.set_event_loop(cls._loop) + + if not cls._loop_thread: + # daemonize so the loop will be shut down on interpreter + # shutdown + cls._loop_thread = Thread(target=cls._loop.run_forever, + daemon=True, name="asyncio_thread") + cls._loop_thread.start() + + @classmethod + def create_timer(cls, timeout, callback): + return AsyncioTimer(timeout, callback, loop=cls._loop) + + def close(self): + with self.lock: + if self.is_closed: + return + self.is_closed = True + + # close from the loop thread to avoid races when removing file + # descriptors + asyncio.run_coroutine_threadsafe( + self._close(), loop=self._loop + ) + + async def _close(self): + log.debug("Closing connection (%s) to %s" % (id(self), self.endpoint)) + if self._write_watcher: + self._write_watcher.cancel() + if self._read_watcher: + self._read_watcher.cancel() + if self._socket: + self._loop.remove_writer(self._socket.fileno()) + self._loop.remove_reader(self._socket.fileno()) + self._socket.close() + + log.debug("Closed socket to %s" % (self.endpoint,)) + + if not self.is_defunct: + self.error_all_requests( + ConnectionShutdown("Connection to %s was closed" % self.endpoint)) + # don't leave in-progress operations hanging + self.connected_event.set() + + def push(self, data): + buff_size = self.out_buffer_size + if len(data) > buff_size: + chunks = [] + for i in range(0, len(data), buff_size): + chunks.append(data[i:i + buff_size]) + else: + chunks = [data] + + if self._loop_thread.ident != get_ident(): + asyncio.run_coroutine_threadsafe( + self._push_msg(chunks), + loop=self._loop + ) + else: + # avoid races/hangs by just scheduling this, not using threadsafe + self._loop.create_task(self._push_msg(chunks)) + + async def _push_msg(self, chunks): + # This lock ensures all chunks of a message are sequential in the Queue + with await self._write_queue_lock: + for chunk in chunks: + self._write_queue.put_nowait(chunk) + + + async def handle_write(self): + while True: + try: + next_msg = await self._write_queue.get() + if next_msg: + await self._loop.sock_sendall(self._socket, next_msg) + except socket.error as err: + log.debug("Exception in send for %s: %s", self, err) + self.defunct(err) + return + except asyncio.CancelledError: + return + + async def handle_read(self): + while True: + try: + buf = await self._loop.sock_recv(self._socket, self.in_buffer_size) + self._iobuf.write(buf) + # sock_recv expects EWOULDBLOCK if socket provides no data, but + # nonblocking ssl sockets raise these instead, so we handle them + # ourselves by yielding to the event loop, where the socket will + # get the reading/writing it "wants" before retrying + except 
(ssl.SSLWantWriteError, ssl.SSLWantReadError): + # Apparently the preferred way to yield to the event loop from within + # a native coroutine based on https://github.com/python/asyncio/issues/284 + await asyncio.sleep(0) + continue + except socket.error as err: + log.debug("Exception during socket recv for %s: %s", + self, err) + self.defunct(err) + return # leave the read loop + except asyncio.CancelledError: + return + + if buf and self._iobuf.tell(): + self.process_io_buffer() + else: + log.debug("Connection %s closed by server", self) + self.close() + return diff --git a/cassandra/io/asyncorereactor.py b/cassandra/io/asyncorereactor.py index 2a83996d17..e1bcafb39e 100644 --- a/cassandra/io/asyncorereactor.py +++ b/cassandra/io/asyncorereactor.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -18,42 +20,194 @@ import os import socket import sys -from threading import Event, Lock, Thread +from threading import Lock, Thread, Event import time import weakref - -from six.moves import range +import sys +import ssl try: from weakref import WeakSet except ImportError: from cassandra.util import WeakSet # noqa -import asyncore - +from cassandra import DependencyException try: - import ssl -except ImportError: - ssl = None # NOQA + import asyncore +except ModuleNotFoundError: + raise DependencyException( + "Unable to import asyncore module. Note that this module has been removed in Python 3.12 " + "so when using the driver with this version (or anything newer) you will need to use one of the " + "other event loop implementations." 
+ ) + +from cassandra.connection import Connection, ConnectionShutdown, NONBLOCKING, Timer, TimerManager -from cassandra.connection import (Connection, ConnectionShutdown, - ConnectionException, NONBLOCKING, - Timer, TimerManager) log = logging.getLogger(__name__) +_dispatcher_map = {} + +def _cleanup(loop): + if loop: + loop._cleanup() + + +class WaitableTimer(Timer): + def __init__(self, timeout, callback): + Timer.__init__(self, timeout, callback) + self.callback = callback + self.event = Event() + + self.final_exception = None + + def finish(self, time_now): + try: + finished = Timer.finish(self, time_now) + if finished: + self.event.set() + return True + return False + + except Exception as e: + self.final_exception = e + self.event.set() + return True + + def wait(self, timeout=None): + self.event.wait(timeout) + if self.final_exception: + raise self.final_exception + + +class _PipeWrapper(object): + + def __init__(self, fd): + self.fd = fd + + def fileno(self): + return self.fd + + def close(self): + os.close(self.fd) + + def getsockopt(self, level, optname, buflen=None): + # act like an unerrored socket for the asyncore error handling + if level == socket.SOL_SOCKET and optname == socket.SO_ERROR and not buflen: + return 0 + raise NotImplementedError() + + +class _AsyncoreDispatcher(asyncore.dispatcher): + + def __init__(self, socket): + asyncore.dispatcher.__init__(self, map=_dispatcher_map) + # inject after to avoid base class validation + self.set_socket(socket) + self._notified = False + + def writable(self): + return False + + def validate(self): + assert not self._notified + self.notify_loop() + assert self._notified + self.loop(0.1) + assert not self._notified + + def loop(self, timeout): + asyncore.loop(timeout=timeout, use_poll=True, map=_dispatcher_map, count=1) + + +class _AsyncorePipeDispatcher(_AsyncoreDispatcher): + + def __init__(self): + self.read_fd, self.write_fd = os.pipe() + _AsyncoreDispatcher.__init__(self, _PipeWrapper(self.read_fd)) + + def writable(self): + return False + + def handle_read(self): + while len(os.read(self.read_fd, 4096)) == 4096: + pass + self._notified = False + + def notify_loop(self): + if not self._notified: + self._notified = True + os.write(self.write_fd, b'x') + + +class _AsyncoreUDPDispatcher(_AsyncoreDispatcher): + """ + Experimental alternate dispatcher for avoiding busy wait in the asyncore loop. It is not used by default because + it relies on local port binding. + Port scanning is not implemented, so multiple clients on one host will collide. This address would need to be set per + instance, or this could be specialized to scan until an address is found. 
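# Editorial sketch, not part of the patch: the self-pipe trick that
# _AsyncorePipeDispatcher above uses to wake a select()-based loop from
# another thread. Writing one byte to the pipe makes the read end readable,
# which pops the loop out of its select() before the timeout expires.
import os
import select
import threading

read_fd, write_fd = os.pipe()

def loop_once(timeout):
    ready, _, _ = select.select([read_fd], [], [], timeout)
    if ready:
        os.read(read_fd, 4096)    # drain the wake-up byte(s)
        return "woken"
    return "timed out"

threading.Timer(0.05, lambda: os.write(write_fd, b'x')).start()
print(loop_once(5.0))             # "woken" well before the 5 second timeout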
+ + To use:: + + from cassandra.io.asyncorereactor import _AsyncoreUDPDispatcher, AsyncoreLoop + AsyncoreLoop._loop_dispatch_class = _AsyncoreUDPDispatcher -def _cleanup(loop_weakref): - try: - loop = loop_weakref() - except ReferenceError: - return + """ + bind_address = ('localhost', 10000) + + def __init__(self): + self._socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self._socket.bind(self.bind_address) + self._socket.setblocking(0) + _AsyncoreDispatcher.__init__(self, self._socket) + + def handle_read(self): + try: + d = self._socket.recvfrom(1) + while d and d[1]: + d = self._socket.recvfrom(1) + except socket.error as e: + pass + self._notified = False - loop._cleanup() + def notify_loop(self): + if not self._notified: + self._notified = True + self._socket.sendto(b'', self.bind_address) + + def loop(self, timeout): + asyncore.loop(timeout=timeout, use_poll=False, map=_dispatcher_map, count=1) + + +class _BusyWaitDispatcher(object): + + max_write_latency = 0.001 + """ + Timeout pushed down to asyncore select/poll. Dictates the amount of time it will sleep before coming back to check + if anything is writable. + """ + + def notify_loop(self): + pass + + def loop(self, timeout): + if not _dispatcher_map: + time.sleep(0.005) + count = timeout // self.max_write_latency + asyncore.loop(timeout=self.max_write_latency, use_poll=True, map=_dispatcher_map, count=count) + + def validate(self): + pass + + def close(self): + pass class AsyncoreLoop(object): + timer_resolution = 0.1 # used as the max interval to be in the io loop before returning to service timeouts + + _loop_dispatch_class = _AsyncorePipeDispatcher if os.name != 'nt' else _BusyWaitDispatcher def __init__(self): self._pid = os.getpid() @@ -65,7 +219,15 @@ def __init__(self): self._timers = TimerManager() - atexit.register(partial(_cleanup, weakref.ref(self))) + try: + dispatcher = self._loop_dispatch_class() + dispatcher.validate() + log.debug("Validated loop dispatch with %s", self._loop_dispatch_class) + except Exception: + log.exception("Failed validating loop dispatch with %s. Using busy wait execution instead.", self._loop_dispatch_class) + dispatcher.close() + dispatcher = _BusyWaitDispatcher() + self._loop_dispatcher = dispatcher def maybe_start(self): should_start = False @@ -80,30 +242,47 @@ def maybe_start(self): self._loop_lock.release() if should_start: - self._thread = Thread(target=self._run_loop, name="cassandra_driver_event_loop") + self._thread = Thread(target=self._run_loop, name="asyncore_cassandra_driver_event_loop") self._thread.daemon = True self._thread.start() + def wake_loop(self): + self._loop_dispatcher.notify_loop() + def _run_loop(self): log.debug("Starting asyncore event loop") with self._loop_lock: while not self._shutdown: try: - asyncore.loop(timeout=0.001, use_poll=True, count=100) + self._loop_dispatcher.loop(self.timer_resolution) self._timers.service_timeouts() - if not asyncore.socket_map: - time.sleep(0.005) - except Exception: - log.debug("Asyncore event loop stopped unexepectedly", exc_info=True) + except Exception as exc: + self._maybe_log_debug("Asyncore event loop stopped unexpectedly", exc_info=exc) break self._started = False - log.debug("Asyncore event loop ended") + self._maybe_log_debug("Asyncore event loop ended") + + def _maybe_log_debug(self, *args, **kwargs): + try: + log.debug(*args, **kwargs) + except Exception: + # TODO: Remove when Python 2 support is removed + # PYTHON-1266. If our logger has disappeared, there's nothing we + # can do, so just log nothing. 
+ pass def add_timer(self, timer): self._timers.add_timer(timer) + # This function is called from a different thread than the event loop + # thread, so for this call to be thread safe, we must wake up the loop + # in case it's stuck at a select + self.wake_loop() + def _cleanup(self): + global _dispatcher_map + self._shutdown = True if not self._thread: return @@ -117,6 +296,20 @@ def _cleanup(self): log.debug("Event loop thread was joined") + # Ensure all connections are closed and in-flight requests cancelled + for conn in tuple(_dispatcher_map.values()): + if conn is not self._loop_dispatcher: + conn.close() + self._timers.service_timeouts() + # Once all the connections are closed, close the dispatcher + self._loop_dispatcher.close() + + log.debug("Dispatchers were closed") + + +_global_loop = None +atexit.register(partial(_cleanup, _global_loop)) + class AsyncoreConnection(Connection, asyncore.dispatcher): """ @@ -124,67 +317,82 @@ class AsyncoreConnection(Connection, asyncore.dispatcher): module in the Python standard library for its event loop. """ - _loop = None - _writable = False _readable = False @classmethod def initialize_reactor(cls): - if not cls._loop: - cls._loop = AsyncoreLoop() + global _global_loop + if not _global_loop: + _global_loop = AsyncoreLoop() else: current_pid = os.getpid() - if cls._loop._pid != current_pid: + if _global_loop._pid != current_pid: log.debug("Detected fork, clearing and reinitializing reactor state") cls.handle_fork() - cls._loop = AsyncoreLoop() + _global_loop = AsyncoreLoop() @classmethod def handle_fork(cls): - if cls._loop: - cls._loop._cleanup() - cls._loop = None + global _dispatcher_map, _global_loop + _dispatcher_map = {} + if _global_loop: + _global_loop._cleanup() + _global_loop = None @classmethod def create_timer(cls, timeout, callback): timer = Timer(timeout, callback) - cls._loop.add_timer(timer) + _global_loop.add_timer(timer) return timer def __init__(self, *args, **kwargs): Connection.__init__(self, *args, **kwargs) - asyncore.dispatcher.__init__(self) self.deque = deque() self.deque_lock = Lock() self._connect_socket() - asyncore.dispatcher.__init__(self, self._socket) + + # start the event loop if needed + _global_loop.maybe_start() + + init_handler = WaitableTimer( + timeout=0, + callback=partial(asyncore.dispatcher.__init__, + self, self._socket, _dispatcher_map) + ) + _global_loop.add_timer(init_handler) + init_handler.wait(kwargs["connect_timeout"]) self._writable = True self._readable = True self._send_options_message() - # start the event loop if needed - self._loop.maybe_start() - def close(self): with self.lock: if self.is_closed: return self.is_closed = True - log.debug("Closing connection (%s) to %s", id(self), self.host) + log.debug("Closing connection (%s) to %s", id(self), self.endpoint) self._writable = False self._readable = False - asyncore.dispatcher.close(self) - log.debug("Closed socket to %s", self.host) + + # We don't have to wait for this to be closed, we can just schedule it + self.create_timer(0, partial(asyncore.dispatcher.close, self)) + + log.debug("Closed socket to %s", self.endpoint) if not self.is_defunct: self.error_all_requests( - ConnectionShutdown("Connection to %s was closed" % self.host)) + ConnectionShutdown("Connection to %s was closed" % self.endpoint)) + + #This happens when the connection is shutdown while waiting for the ReadyMessage + if not self.connected_event.is_set(): + self.last_error = ConnectionShutdown("Connection to %s was closed" % self.endpoint) + # don't leave in-progress 
operations hanging self.connected_event.set() @@ -208,7 +416,8 @@ def handle_write(self): sent = self.send(next_msg) self._readable = True except socket.error as err: - if (err.args[0] in NONBLOCKING): + if (err.args[0] in NONBLOCKING or + err.args[0] in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE)): with self.deque_lock: self.deque.appendleft(next_msg) else: @@ -229,11 +438,17 @@ def handle_read(self): if len(buf) < self.in_buffer_size: break except socket.error as err: - if ssl and isinstance(err, ssl.SSLError): - if err.args[0] not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): + if isinstance(err, ssl.SSLError): + if err.args[0] in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): + if not self._iobuf.tell(): + return + else: self.defunct(err) return - elif err.args[0] not in NONBLOCKING: + elif err.args[0] in NONBLOCKING: + if not self._iobuf.tell(): + return + else: self.defunct(err) return @@ -254,9 +469,10 @@ def push(self, data): with self.deque_lock: self.deque.extend(chunks) self._writable = True + _global_loop.wake_loop() def writable(self): return self._writable def readable(self): - return self._readable or (self.is_control_connection and not (self.is_defunct or self.is_closed)) + return self._readable or ((self.is_control_connection or self._continuous_paging_sessions) and not (self.is_defunct or self.is_closed)) diff --git a/cassandra/io/eventletreactor.py b/cassandra/io/eventletreactor.py index dfaea8bfb4..94e1e49544 100644 --- a/cassandra/io/eventletreactor.py +++ b/cassandra/io/eventletreactor.py @@ -1,11 +1,13 @@ # Copyright 2014 Symantec Corporation -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
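# Editorial sketch, not part of the patch: the fork-detection idiom that
# initialize_reactor() above (and the libev reactor later in this patch)
# follows. The module-level loop remembers the PID it was created in, and a
# forked child discards the inherited loop rather than reusing its thread and
# sockets. EventLoop is a stand-in, not AsyncoreLoop or LibevLoop.
import os

class EventLoop(object):
    def __init__(self):
        self._pid = os.getpid()

_global_loop = None

def initialize_reactor():
    global _global_loop
    if _global_loop is None:
        _global_loop = EventLoop()
    elif _global_loop._pid != os.getpid():
        # running in a forked child: rebuild the reactor state from scratch
        _global_loop = EventLoop()
    return _global_loop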
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -15,32 +17,32 @@ # Originally derived from MagnetoDB source: # https://github.com/stackforge/magnetodb/blob/2015.1.0b1/magnetodb/common/cassandra/io/eventletreactor.py - -from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL import eventlet from eventlet.green import socket -import ssl from eventlet.queue import Queue +from greenlet import GreenletExit import logging -import os from threading import Event import time -from six.moves import xrange - from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager +try: + from eventlet.green.OpenSSL import SSL + _PYOPENSSL = True +except ImportError as e: + _PYOPENSSL = False + no_pyopenssl_error = e log = logging.getLogger(__name__) -def is_timeout(err): - return ( - err in (EINPROGRESS, EALREADY, EWOULDBLOCK) or - (err == EINVAL and os.name in ('nt', 'ce')) or - (isinstance(err, ssl.SSLError) and err.args[0] == 'timed out') or - isinstance(err, socket.timeout) - ) +def _check_pyopenssl(): + if not _PYOPENSSL: + raise ImportError( + "{}, pyOpenSSL must be installed to enable " + "SSL support with the Eventlet event loop".format(str(no_pyopenssl_error)) + ) class EventletConnection(Connection): @@ -92,7 +94,7 @@ def service_timeouts(cls): def __init__(self, *args, **kwargs): Connection.__init__(self, *args, **kwargs) - + self.uses_legacy_ssl_options = self.ssl_options and not self.ssl_context self._write_queue = Queue() self._connect_socket() @@ -101,13 +103,37 @@ def __init__(self, *args, **kwargs): self._write_watcher = eventlet.spawn(lambda: self.handle_write()) self._send_options_message() + def _wrap_socket_from_context(self): + _check_pyopenssl() + rv = SSL.Connection(self.ssl_context, self._socket) + rv.set_connect_state() + if self.ssl_options and 'server_hostname' in self.ssl_options: + # This is necessary for SNI + rv.set_tlsext_host_name(self.ssl_options['server_hostname'].encode('ascii')) + return rv + + def _initiate_connection(self, sockaddr): + if self.uses_legacy_ssl_options: + super(EventletConnection, self)._initiate_connection(sockaddr) + else: + self._socket.connect(sockaddr) + if self.ssl_context or self.ssl_options: + self._socket.do_handshake() + + def _validate_hostname(self): + if not self.uses_legacy_ssl_options: + cert_name = self._socket.get_peer_certificate().get_subject().commonName + if cert_name != self.endpoint.address: + raise Exception("Hostname verification failed! 
Certificate name '{}' " + "doesn't match endpoint '{}'".format(cert_name, self.endpoint.address)) + def close(self): with self.lock: if self.is_closed: return self.is_closed = True - log.debug("Closing connection (%s) to %s" % (id(self), self.host)) + log.debug("Closing connection (%s) to %s" % (id(self), self.endpoint)) cur_gthread = eventlet.getcurrent() @@ -117,11 +143,11 @@ def close(self): self._write_watcher.kill() if self._socket: self._socket.close() - log.debug("Closed socket to %s" % (self.host,)) + log.debug("Closed socket to %s" % (self.endpoint,)) if not self.is_defunct: self.error_all_requests( - ConnectionShutdown("Connection to %s was closed" % self.host)) + ConnectionShutdown("Connection to %s was closed" % self.endpoint)) # don't leave in-progress operations hanging self.connected_event.set() @@ -138,6 +164,8 @@ def handle_write(self): log.debug("Exception during socket send for %s: %s", self, err) self.defunct(err) return # Leave the write loop + except GreenletExit: # graceful greenthread exit + return def handle_read(self): while True: @@ -145,14 +173,14 @@ def handle_read(self): buf = self._socket.recv(self.in_buffer_size) self._iobuf.write(buf) except socket.error as err: - if is_timeout(err): - continue log.debug("Exception during socket recv for %s: %s", self, err) self.defunct(err) return # leave the read loop + except GreenletExit: # graceful greenthread exit + return - if self._iobuf.tell(): + if buf and self._iobuf.tell(): self.process_io_buffer() else: log.debug("Connection %s closed by server", self) @@ -161,5 +189,5 @@ def handle_read(self): def push(self, data): chunk_size = self.out_buffer_size - for i in xrange(0, len(data), chunk_size): + for i in range(0, len(data), chunk_size): self._write_queue.put(data[i:i + chunk_size]) diff --git a/cassandra/io/geventreactor.py b/cassandra/io/geventreactor.py index 6e62a38b0e..8ad4ee99e7 100644 --- a/cassandra/io/geventreactor.py +++ b/cassandra/io/geventreactor.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
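# Editorial sketch, not part of the patch: the optional-dependency pattern the
# eventlet reactor above (and the Twisted reactor later in this patch) uses
# for pyOpenSSL -- remember the original ImportError and only surface it when
# SSL support is actually requested. require_pyopenssl() is an illustrative
# name, not a driver function.
try:
    from OpenSSL import SSL          # noqa: F401
    _HAS_PYOPENSSL = True
    _pyopenssl_error = None
except ImportError as exc:
    _HAS_PYOPENSSL = False
    _pyopenssl_error = exc

def require_pyopenssl():
    if not _HAS_PYOPENSSL:
        raise ImportError(
            "{}, pyOpenSSL must be installed to enable SSL support".format(_pyopenssl_error))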
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -14,17 +16,12 @@ import gevent import gevent.event from gevent.queue import Queue -from gevent import select, socket +from gevent import socket import gevent.ssl -from functools import partial import logging -import os import time -from six.moves import range - -from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager @@ -32,14 +29,6 @@ log = logging.getLogger(__name__) -def is_timeout(err): - return ( - err in (EINPROGRESS, EALREADY, EWOULDBLOCK) or - (err == EINVAL and os.name in ('nt', 'ce')) or - isinstance(err, socket.timeout) - ) - - class GeventConnection(Connection): """ An implementation of :class:`.Connection` that utilizes ``gevent``. @@ -98,18 +87,18 @@ def close(self): return self.is_closed = True - log.debug("Closing connection (%s) to %s" % (id(self), self.host)) + log.debug("Closing connection (%s) to %s" % (id(self), self.endpoint)) if self._read_watcher: self._read_watcher.kill(block=False) if self._write_watcher: self._write_watcher.kill(block=False) if self._socket: self._socket.close() - log.debug("Closed socket to %s" % (self.host,)) + log.debug("Closed socket to %s" % (self.endpoint,)) if not self.is_defunct: self.error_all_requests( - ConnectionShutdown("Connection to %s was closed" % self.host)) + ConnectionShutdown("Connection to %s was closed" % self.endpoint)) # don't leave in-progress operations hanging self.connected_event.set() @@ -118,48 +107,26 @@ def handle_close(self): self.close() def handle_write(self): - run_select = partial(select.select, (), (self._socket,), ()) while True: try: next_msg = self._write_queue.get() - run_select() - except Exception as exc: - if not self.is_closed: - log.debug("Exception during write select() for %s: %s", self, exc) - self.defunct(exc) - return - - try: self._socket.sendall(next_msg) except socket.error as err: - log.debug("Exception during socket sendall for %s: %s", self, err) + log.debug("Exception in send for %s: %s", self, err) self.defunct(err) - return # Leave the write loop + return def handle_read(self): - run_select = partial(select.select, (self._socket,), (), ()) while True: try: - run_select() - except Exception as exc: - if not self.is_closed: - log.debug("Exception during read select() for %s: %s", self, exc) - self.defunct(exc) - return - - try: - while True: - buf = self._socket.recv(self.in_buffer_size) - self._iobuf.write(buf) - if len(buf) < self.in_buffer_size: - break + buf = self._socket.recv(self.in_buffer_size) + self._iobuf.write(buf) except socket.error as err: - if not is_timeout(err): - log.debug("Exception during socket recv for %s: %s", self, err) - self.defunct(err) - return # leave the read loop + log.debug("Exception in read for %s: %s", self, err) + self.defunct(err) + return # leave the read loop - if self._iobuf.tell(): + if buf and self._iobuf.tell(): self.process_io_buffer() else: log.debug("Connection %s closed by server", self) diff --git a/cassandra/io/libevreactor.py b/cassandra/io/libevreactor.py index a3e96a9a03..275f79c374 100644 --- a/cassandra/io/libevreactor.py +++ b/cassandra/io/libevreactor.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. 
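# Editorial sketch, not part of the patch: the simplified read loop the gevent
# and eventlet connections above converge on. Cooperative (green) sockets
# block the greenlet instead of raising EWOULDBLOCK, so the loop only needs to
# distinguish EOF (an empty recv) from a real socket error. The callbacks are
# stand-ins for process_io_buffer()/close()/defunct().
def read_loop(sock, process, on_closed, on_error, bufsize=65536):
    while True:
        try:
            buf = sock.recv(bufsize)
        except OSError as err:
            return on_error(err)      # defunct the connection
        if buf:
            process(buf)              # feed the received bytes to the parser
        else:
            return on_closed()        # empty read: server closed the socket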
+# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -20,33 +22,29 @@ import ssl from threading import Lock, Thread import time -import weakref - -from six.moves import range -from cassandra.connection import (Connection, ConnectionShutdown, - NONBLOCKING, Timer, TimerManager) +from cassandra import DependencyException try: import cassandra.io.libevwrapper as libev except ImportError: - raise ImportError( + raise DependencyException( "The C extension needed to use libev was not found. This " "probably means that you didn't have the required build dependencies " "when installing the driver. See " - "http://datastax.github.io/python-driver/installation.html#c-extensions " + "https://docs.datastax.com/en/developer/python-driver/latest/installation/index.html#c-extensions " "for instructions on installing build dependencies and building " "the C extension.") +from cassandra.connection import (Connection, ConnectionShutdown, + NONBLOCKING, Timer, TimerManager) + log = logging.getLogger(__name__) -def _cleanup(loop_weakref): - try: - loop = loop_weakref() - except ReferenceError: - return - loop._cleanup() +def _cleanup(loop): + if loop: + loop._cleanup() class LibevLoop(object): @@ -63,6 +61,7 @@ def __init__(self): self._started = False self._shutdown = False self._lock = Lock() + self._lock_thread = Lock() self._thread = None @@ -83,8 +82,6 @@ def __init__(self): self._timers = TimerManager() self._loop_timer = libev.Timer(self._loop, self._on_loop_timer) - atexit.register(partial(_cleanup, weakref.ref(self))) - def maybe_start(self): should_start = False with self._lock: @@ -94,18 +91,20 @@ def maybe_start(self): should_start = True if should_start: - self._thread = Thread(target=self._run_loop, name="event_loop") - self._thread.daemon = True - self._thread.start() + with self._lock_thread: + if not self._shutdown: + self._thread = Thread(target=self._run_loop, name="event_loop") + self._thread.daemon = True + self._thread.start() self._notifier.send() def _run_loop(self): while True: - end_condition = self._loop.start() + self._loop.start() # there are still active watchers, no deadlock with self._lock: - if not self._shutdown and (end_condition or self._live_conns): + if not self._shutdown and self._live_conns: log.debug("Restarting event loop") continue else: @@ -121,21 +120,22 @@ def _cleanup(self): for conn in self._live_conns | self._new_conns | self._closed_conns: conn.close() - if conn._write_watcher: - conn._write_watcher.stop() - if conn._read_watcher: - conn._read_watcher.stop() + for watcher in (conn._write_watcher, conn._read_watcher): + if watcher: + watcher.stop() self.notify() # wake the timer watcher - log.debug("Waiting for event loop thread to join...") - self._thread.join(timeout=1.0) + + # PYTHON-752 
Thread might have just been created and not started + with self._lock_thread: + self._thread.join(timeout=1.0) + if self._thread.is_alive(): log.warning( "Event loop thread could not be joined, so shutdown may not be clean. " "Please call Cluster.shutdown() to avoid this.") log.debug("Event loop thread was joined") - self._loop = None def add_timer(self, timer): self._timers.add_timer(timer) @@ -224,11 +224,14 @@ def _loop_will_run(self, prepare): self._notifier.send() +_global_loop = None +atexit.register(partial(_cleanup, _global_loop)) + + class LibevConnection(Connection): """ An implementation of :class:`.Connection` that uses libev for its event loop. """ - _libevloop = None _write_watcher_is_active = False _read_watcher = None _write_watcher = None @@ -236,24 +239,26 @@ class LibevConnection(Connection): @classmethod def initialize_reactor(cls): - if not cls._libevloop: - cls._libevloop = LibevLoop() + global _global_loop + if not _global_loop: + _global_loop = LibevLoop() else: - if cls._libevloop._pid != os.getpid(): + if _global_loop._pid != os.getpid(): log.debug("Detected fork, clearing and reinitializing reactor state") cls.handle_fork() - cls._libevloop = LibevLoop() + _global_loop = LibevLoop() @classmethod def handle_fork(cls): - if cls._libevloop: - cls._libevloop._cleanup() - cls._libevloop = None + global _global_loop + if _global_loop: + _global_loop._cleanup() + _global_loop = None @classmethod def create_timer(cls, timeout, callback): timer = Timer(timeout, callback) - cls._libevloop.add_timer(timer) + _global_loop.add_timer(timer) return timer def __init__(self, *args, **kwargs): @@ -264,16 +269,16 @@ def __init__(self, *args, **kwargs): self._connect_socket() self._socket.setblocking(0) - with self._libevloop._lock: - self._read_watcher = libev.IO(self._socket.fileno(), libev.EV_READ, self._libevloop._loop, self.handle_read) - self._write_watcher = libev.IO(self._socket.fileno(), libev.EV_WRITE, self._libevloop._loop, self.handle_write) + with _global_loop._lock: + self._read_watcher = libev.IO(self._socket.fileno(), libev.EV_READ, _global_loop._loop, self.handle_read) + self._write_watcher = libev.IO(self._socket.fileno(), libev.EV_WRITE, _global_loop._loop, self.handle_write) self._send_options_message() - self._libevloop.connection_created(self) + _global_loop.connection_created(self) # start the global event loop if needed - self._libevloop.maybe_start() + _global_loop.maybe_start() def close(self): with self.lock: @@ -281,15 +286,16 @@ def close(self): return self.is_closed = True - log.debug("Closing connection (%s) to %s", id(self), self.host) - self._libevloop.connection_destroyed(self) + log.debug("Closing connection (%s) to %s", id(self), self.endpoint) + + _global_loop.connection_destroyed(self) self._socket.close() - log.debug("Closed socket to %s", self.host) + log.debug("Closed socket to %s", self.endpoint) # don't leave in-progress operations hanging if not self.is_defunct: self.error_all_requests( - ConnectionShutdown("Connection to %s was closed" % self.host)) + ConnectionShutdown("Connection to %s was closed" % self.endpoint)) def handle_write(self, watcher, revents, errno=None): if revents & libev.EV_ERROR: @@ -306,12 +312,17 @@ def handle_write(self, watcher, revents, errno=None): with self._deque_lock: next_msg = self.deque.popleft() except IndexError: + if not self._socket_writable: + self._socket_writable = True return try: sent = self._socket.send(next_msg) except socket.error as err: - if (err.args[0] in NONBLOCKING): + if (err.args[0] in 
NONBLOCKING or + err.args[0] in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE)): + if err.args[0] in NONBLOCKING: + self._socket_writable = False with self._deque_lock: self.deque.appendleft(next_msg) else: @@ -321,6 +332,11 @@ def handle_write(self, watcher, revents, errno=None): if sent < len(next_msg): with self._deque_lock: self.deque.appendleft(next_msg[sent:]) + # we've seen some cases that 0 is returned instead of NONBLOCKING. But usually, + # we don't expect this to happen. https://bugs.python.org/issue20951 + if sent == 0: + self._socket_writable = False + return def handle_read(self, watcher, revents, errno=None): if revents & libev.EV_ERROR: @@ -338,11 +354,17 @@ def handle_read(self, watcher, revents, errno=None): if len(buf) < self.in_buffer_size: break except socket.error as err: - if ssl and isinstance(err, ssl.SSLError): - if err.args[0] not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): + if isinstance(err, ssl.SSLError): + if err.args[0] in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): + if not self._iobuf.tell(): + return + else: self.defunct(err) return - elif err.args[0] not in NONBLOCKING: + elif err.args[0] in NONBLOCKING: + if not self._iobuf.tell(): + return + else: self.defunct(err) return @@ -363,4 +385,4 @@ def push(self, data): with self._deque_lock: self.deque.extend(chunks) - self._libevloop.notify() + _global_loop.notify() diff --git a/cassandra/io/libevwrapper.c b/cassandra/io/libevwrapper.c index 99e1df30f7..85ed551951 100644 --- a/cassandra/io/libevwrapper.c +++ b/cassandra/io/libevwrapper.c @@ -583,7 +583,6 @@ static PyMethodDef module_methods[] = { PyDoc_STRVAR(module_doc, "libev wrapper methods"); -#if PY_MAJOR_VERSION >= 3 static struct PyModuleDef moduledef = { PyModuleDef_HEAD_INIT, "libevwrapper", @@ -600,13 +599,6 @@ static struct PyModuleDef moduledef = { PyObject * PyInit_libevwrapper(void) - -# else -# define INITERROR return - -void -initlibevwrapper(void) -#endif { PyObject *module = NULL; @@ -629,11 +621,7 @@ initlibevwrapper(void) if (PyType_Ready(&libevwrapper_TimerType) < 0) INITERROR; -# if PY_MAJOR_VERSION >= 3 module = PyModule_Create(&moduledef); -# else - module = Py_InitModule3("libevwrapper", module_methods, module_doc); -# endif if (module == NULL) INITERROR; @@ -665,11 +653,14 @@ initlibevwrapper(void) if (PyModule_AddObject(module, "Timer", (PyObject *)&libevwrapper_TimerType) == -1) INITERROR; +#if PY_MAJOR_VERSION < 3 || (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION < 7) + // Since CPython 3.7, `Py_Initialize()` routing always initializes GIL. + // Routine `PyEval_ThreadsInitialized()` has been deprecated in CPython 3.7 + // and completely removed in CPython 3.13. if (!PyEval_ThreadsInitialized()) { PyEval_InitThreads(); } +#endif -#if PY_MAJOR_VERSION >= 3 return module; -#endif } diff --git a/cassandra/io/twistedreactor.py b/cassandra/io/twistedreactor.py index ccd976bd2d..b55ac4d1a3 100644 --- a/cassandra/io/twistedreactor.py +++ b/cassandra/io/twistedreactor.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -16,16 +18,26 @@ ( https://twistedmatrix.com ). """ import atexit -from functools import partial import logging -from threading import Thread, Lock import time -from twisted.internet import reactor, protocol +from functools import partial +from threading import Thread, Lock import weakref -from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager - - +from twisted.internet import reactor, protocol +from twisted.internet.endpoints import connectProtocol, TCP4ClientEndpoint, SSL4ClientEndpoint +from twisted.internet.interfaces import IOpenSSLClientConnectionCreator +from twisted.python.failure import Failure +from zope.interface import implementer + +from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager, ConnectionException + +try: + from OpenSSL import SSL + _HAS_SSL = True +except ImportError as e: + _HAS_SSL = False + import_exception = e log = logging.getLogger(__name__) @@ -42,6 +54,9 @@ class TwistedConnectionProtocol(protocol.Protocol): made events. """ + def __init__(self, connection): + self.connection = connection + def dataReceived(self, data): """ Callback function that is called when data has been received @@ -50,8 +65,8 @@ def dataReceived(self, data): Reaches back to the Connection object and queues the data for processing. """ - self.transport.connector.factory.conn._iobuf.write(data) - self.transport.connector.factory.conn.handle_read() + self.connection._iobuf.write(data) + self.connection.handle_read() def connectionMade(self): """ @@ -60,48 +75,12 @@ def connectionMade(self): Reaches back to the Connection object and confirms that the connection is ready. """ - self.transport.connector.factory.conn.client_connection_made() + self.connection.client_connection_made(self.transport) def connectionLost(self, reason): # reason is a Failure instance - self.transport.connector.factory.conn.defunct(reason.value) - - -class TwistedConnectionClientFactory(protocol.ClientFactory): - - def __init__(self, connection): - # ClientFactory does not define __init__() in parent classes - # and does not inherit from object. - self.conn = connection - - def buildProtocol(self, addr): - """ - Twisted function that defines which kind of protocol to use - in the ClientFactory. - """ - return TwistedConnectionProtocol() - - def clientConnectionFailed(self, connector, reason): - """ - Overridden twisted callback which is called when the - connection attempt fails. - """ - log.debug("Connect failed: %s", reason) - self.conn.defunct(reason.value) - - def clientConnectionLost(self, connector, reason): - """ - Overridden twisted callback which is called when the - connection goes away (cleanly or otherwise). - - It should be safe to call defunct() here instead of just close, because - we can assume that if the connection was closed cleanly, there are no - requests to error out. If this assumption turns out to be false, we - can call close() instead of defunct() when "reason" is an appropriate - type. 
- """ log.debug("Connect lost: %s", reason) - self.conn.defunct(reason.value) + self.connection.defunct(reason.value) class TwistedLoop(object): @@ -119,12 +98,15 @@ def maybe_start(self): with self._lock: if not reactor.running: self._thread = Thread(target=reactor.run, - name="cassandra_driver_event_loop", + name="cassandra_driver_twisted_event_loop", kwargs={'installSignalHandlers': False}) self._thread.daemon = True self._thread.start() atexit.register(partial(_cleanup, weakref.ref(self))) + def _reactor_stopped(self): + return reactor._stopped + def _cleanup(self): if self._thread: reactor.callFromThread(reactor.stop) @@ -157,6 +139,48 @@ def _on_loop_timer(self): self._schedule_timeout(self._timers.next_timeout) +@implementer(IOpenSSLClientConnectionCreator) +class _SSLCreator(object): + def __init__(self, endpoint, ssl_context, ssl_options, check_hostname, timeout): + self.endpoint = endpoint + self.ssl_options = ssl_options + self.check_hostname = check_hostname + self.timeout = timeout + + if ssl_context: + self.context = ssl_context + else: + self.context = SSL.Context(SSL.TLSv1_METHOD) + if "certfile" in self.ssl_options: + self.context.use_certificate_file(self.ssl_options["certfile"]) + if "keyfile" in self.ssl_options: + self.context.use_privatekey_file(self.ssl_options["keyfile"]) + if "ca_certs" in self.ssl_options: + self.context.load_verify_locations(self.ssl_options["ca_certs"]) + if "cert_reqs" in self.ssl_options: + self.context.set_verify( + self.ssl_options["cert_reqs"], + callback=self.verify_callback + ) + self.context.set_info_callback(self.info_callback) + + def verify_callback(self, connection, x509, errnum, errdepth, ok): + return ok + + def info_callback(self, connection, where, ret): + if where & SSL.SSL_CB_HANDSHAKE_DONE: + if self.check_hostname and self.endpoint.address != connection.get_peer_certificate().get_subject().commonName: + transport = connection.get_app_data() + transport.failVerification(Failure(ConnectionException("Hostname verification failed", self.endpoint))) + + def clientConnectionForTLS(self, tlsProtocol): + connection = SSL.Connection(self.context, None) + connection.set_app_data(tlsProtocol) + if self.ssl_options and "server_hostname" in self.ssl_options: + connection.set_tlsext_host_name(self.ssl_options['server_hostname'].encode('ascii')) + return connection + + class TwistedConnection(Connection): """ An implementation of :class:`.Connection` that utilizes the @@ -189,27 +213,62 @@ def __init__(self, *args, **kwargs): self.is_closed = True self.connector = None + self.transport = None reactor.callFromThread(self.add_connection) self._loop.maybe_start() + def _check_pyopenssl(self): + if self.ssl_context or self.ssl_options: + if not _HAS_SSL: + raise ImportError( + str(import_exception) + + ', pyOpenSSL must be installed to enable SSL support with the Twisted event loop' + ) + def add_connection(self): """ Convenience function to connect and store the resulting connector. """ - self.connector = reactor.connectTCP( - host=self.host, port=self.port, - factory=TwistedConnectionClientFactory(self), - timeout=self.connect_timeout) - - def client_connection_made(self): + host, port = self.endpoint.resolve() + if self.ssl_context or self.ssl_options: + # Can't use optionsForClientTLS here because it *forces* hostname verification. 
+ # Cool they enforce strong security, but we have to be able to turn it off + self._check_pyopenssl() + + ssl_connection_creator = _SSLCreator( + self.endpoint, + self.ssl_context if self.ssl_context else None, + self.ssl_options, + self._check_hostname, + self.connect_timeout, + ) + + endpoint = SSL4ClientEndpoint( + reactor, + host, + port, + sslContextFactory=ssl_connection_creator, + timeout=self.connect_timeout, + ) + else: + endpoint = TCP4ClientEndpoint( + reactor, + host, + port, + timeout=self.connect_timeout + ) + connectProtocol(endpoint, TwistedConnectionProtocol(self)) + + def client_connection_made(self, transport): """ Called by twisted protocol when a connection attempt has succeeded. """ with self.lock: self.is_closed = False + self.transport = transport self._send_options_message() def close(self): @@ -221,13 +280,13 @@ def close(self): return self.is_closed = True - log.debug("Closing connection (%s) to %s", id(self), self.host) - self.connector.disconnect() - log.debug("Closed socket to %s", self.host) + log.debug("Closing connection (%s) to %s", id(self), self.endpoint) + reactor.callFromThread(self.transport.connector.disconnect) + log.debug("Closed socket to %s", self.endpoint) if not self.is_defunct: self.error_all_requests( - ConnectionShutdown("Connection to %s was closed" % self.host)) + ConnectionShutdown("Connection to %s was closed" % self.endpoint)) # don't leave in-progress operations hanging self.connected_event.set() @@ -246,4 +305,4 @@ def push(self, data): it is not thread-safe, so we schedule it to run from within the event loop when it gets the chance. """ - reactor.callFromThread(self.connector.transport.write, data) + reactor.callFromThread(self.transport.write, data) diff --git a/cassandra/ioutils.pyx b/cassandra/ioutils.pyx index c59a6a0cf4..91c2bf9542 100644 --- a/cassandra/ioutils.pyx +++ b/cassandra/ioutils.pyx @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/cassandra/marshal.py b/cassandra/marshal.py index 5a523d6381..e8733f0544 100644 --- a/cassandra/marshal.py +++ b/cassandra/marshal.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
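# Illustrative sketch (not part of the patch): TwistedConnection.push() above
# funnels writes through reactor.callFromThread because Twisted transports may
# only be used from the reactor thread. A hypothetical standalone helper doing
# the same thing:
from twisted.internet import reactor

def write_from_any_thread(transport, data: bytes) -> None:
    # callFromThread queues the call to run inside the reactor loop, the only
    # thread that is allowed to touch the transport.
    reactor.callFromThread(transport.write, data)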
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import six import struct @@ -28,6 +29,7 @@ def _make_packer(format_string): int8_pack, int8_unpack = _make_packer('>b') uint64_pack, uint64_unpack = _make_packer('>Q') uint32_pack, uint32_unpack = _make_packer('>I') +uint32_le_pack, uint32_le_unpack = _make_packer('H') uint8_pack, uint8_unpack = _make_packer('>B') float_pack, float_unpack = _make_packer('>f') @@ -44,28 +46,16 @@ def _make_packer(format_string): v3_header_unpack = v3_header_struct.unpack -if six.PY3: - def varint_unpack(term): - val = int(''.join("%02x" % i for i in term), 16) - if (term[0] & 128) != 0: - len_term = len(term) # pulling this out of the expression to avoid overflow in cython optimized code - val -= 1 << (len_term * 8) - return val -else: - def varint_unpack(term): # noqa - val = int(term.encode('hex'), 16) - if (ord(term[0]) & 128) != 0: - len_term = len(term) # pulling this out of the expression to avoid overflow in cython optimized code - val = val - (1 << (len_term * 8)) - return val +def varint_unpack(term): + val = int(''.join("%02x" % i for i in term), 16) + if (term[0] & 128) != 0: + len_term = len(term) # pulling this out of the expression to avoid overflow in cython optimized code + val -= 1 << (len_term * 8) + return val -def bitlength(n): - bitlen = 0 - while n > 0: - n >>= 1 - bitlen += 1 - return bitlen +def bit_length(n): + return int.bit_length(n) def varint_pack(big): @@ -73,7 +63,7 @@ def varint_pack(big): if big == 0: return b'\x00' if big < 0: - bytelength = bitlength(abs(big) - 1) // 8 + 1 + bytelength = bit_length(abs(big) - 1) // 8 + 1 big = (1 << bytelength * 8) + big pos = False revbytes = bytearray() @@ -83,4 +73,119 @@ def varint_pack(big): if pos and revbytes[-1] & 0x80: revbytes.append(0) revbytes.reverse() - return six.binary_type(revbytes) + return bytes(revbytes) + + +point_be = struct.Struct('>dd') +point_le = struct.Struct('ddd') +circle_le = struct.Struct('> 63) + + +def decode_zig_zag(n): + return (n >> 1) ^ -(n & 1) + + +def vints_unpack(term): # noqa + values = [] + n = 0 + while n < len(term): + first_byte = term[n] + + if (first_byte & 128) == 0: + val = first_byte + else: + num_extra_bytes = 8 - (~first_byte & 0xff).bit_length() + val = first_byte & (0xff >> num_extra_bytes) + end = n + num_extra_bytes + while n < end: + n += 1 + val <<= 8 + val |= term[n] & 0xff + + n += 1 + values.append(decode_zig_zag(val)) + + return tuple(values) + +def vints_pack(values): + revbytes = bytearray() + values = [int(v) for v in values[::-1]] + for value in values: + v = encode_zig_zag(value) + if v < 128: + revbytes.append(v) + else: + num_extra_bytes = 0 + num_bits = v.bit_length() + # We need to reserve (num_extra_bytes+1) bits in the first byte + # i.e. with 1 extra byte, the first byte needs to be something like '10XXXXXX' # 2 bits reserved + # i.e. 
with 8 extra bytes, the first byte needs to be '11111111' # 8 bits reserved + reserved_bits = num_extra_bytes + 1 + while num_bits > (8-(reserved_bits)): + num_extra_bytes += 1 + num_bits -= 8 + reserved_bits = min(num_extra_bytes + 1, 8) + revbytes.append(v & 0xff) + v >>= 8 + + if num_extra_bytes > 8: + raise ValueError('Value %d is too big and cannot be encoded as vint' % value) + + # We can now store the last bits in the first byte + n = 8 - num_extra_bytes + v |= (0xff >> n << n) + revbytes.append(abs(v)) + + revbytes.reverse() + return bytes(revbytes) + +def uvint_unpack(bytes): + first_byte = bytes[0] + + if (first_byte & 128) == 0: + return (first_byte,1) + + num_extra_bytes = 8 - (~first_byte & 0xff).bit_length() + rv = first_byte & (0xff >> num_extra_bytes) + for idx in range(1,num_extra_bytes + 1): + new_byte = bytes[idx] + rv <<= 8 + rv |= new_byte & 0xff + + return (rv, num_extra_bytes + 1) + +def uvint_pack(val): + rv = bytearray() + if val < 128: + rv.append(val) + else: + v = val + num_extra_bytes = 0 + num_bits = v.bit_length() + # We need to reserve (num_extra_bytes+1) bits in the first byte + # i.e. with 1 extra byte, the first byte needs to be something like '10XXXXXX' # 2 bits reserved + # i.e. with 8 extra bytes, the first byte needs to be '11111111' # 8 bits reserved + reserved_bits = num_extra_bytes + 1 + while num_bits > (8-(reserved_bits)): + num_extra_bytes += 1 + num_bits -= 8 + reserved_bits = min(num_extra_bytes + 1, 8) + rv.append(v & 0xff) + v >>= 8 + + if num_extra_bytes > 8: + raise ValueError('Value %d is too big and cannot be encoded as vint' % val) + + # We can now store the last bits in the first byte + n = 8 - num_extra_bytes + v |= (0xff >> n << n) + rv.append(abs(v)) + + rv.reverse() + return bytes(rv) diff --git a/cassandra/metadata.py b/cassandra/metadata.py index 1d04b4c964..2c13f92e42 100644 --- a/cassandra/metadata.py +++ b/cassandra/metadata.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,17 +14,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
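# Illustrative sketch (not part of the patch): the vint helpers above rely on
# zig-zag encoding so that small negative integers become small positive ones
# and therefore fit in few bytes. Assuming the standard definitions used by
# marshal.py (64-bit arithmetic shift for the sign bit):
def encode_zig_zag(n: int) -> int:
    return (n << 1) ^ (n >> 63)

def decode_zig_zag(n: int) -> int:
    return (n >> 1) ^ -(n & 1)

# 0, -1, 1, -2, 2 ... map to 0, 1, 2, 3, 4 ...
assert [encode_zig_zag(v) for v in (0, -1, 1, -2, 2)] == [0, 1, 2, 3, 4]
assert all(decode_zig_zag(encode_zig_zag(v)) == v for v in range(-1000, 1000))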
-from bisect import bisect_right -from collections import defaultdict, Mapping +from binascii import unhexlify +from bisect import bisect_left +from collections import defaultdict +from collections.abc import Mapping +from functools import total_ordering from hashlib import md5 -from itertools import islice, cycle import json import logging import re -import six -from six.moves import zip import sys from threading import RLock +import struct +import random murmur3 = None try: @@ -36,22 +40,31 @@ from cassandra.marshal import varint_unpack from cassandra.protocol import QueryMessage from cassandra.query import dict_factory, bind_params -from cassandra.util import OrderedDict +from cassandra.util import OrderedDict, Version +from cassandra.pool import HostDistance +from cassandra.connection import EndPoint log = logging.getLogger(__name__) cql_keywords = set(( 'add', 'aggregate', 'all', 'allow', 'alter', 'and', 'apply', 'as', 'asc', 'ascii', 'authorize', 'batch', 'begin', 'bigint', 'blob', 'boolean', 'by', 'called', 'clustering', 'columnfamily', 'compact', 'contains', 'count', - 'counter', 'create', 'custom', 'date', 'decimal', 'delete', 'desc', 'describe', 'distinct', 'double', 'drop', + 'counter', 'create', 'custom', 'date', 'decimal', 'default', 'delete', 'desc', 'describe', 'deterministic', 'distinct', 'double', 'drop', 'entries', 'execute', 'exists', 'filtering', 'finalfunc', 'float', 'from', 'frozen', 'full', 'function', 'functions', 'grant', 'if', 'in', 'index', 'inet', 'infinity', 'initcond', 'input', 'insert', 'int', 'into', 'is', 'json', - 'key', 'keys', 'keyspace', 'keyspaces', 'language', 'limit', 'list', 'login', 'map', 'materialized', 'modify', 'nan', 'nologin', - 'norecursive', 'nosuperuser', 'not', 'null', 'of', 'on', 'options', 'or', 'order', 'password', 'permission', + 'key', 'keys', 'keyspace', 'keyspaces', 'language', 'limit', 'list', 'login', 'map', 'materialized', 'mbean', 'mbeans', 'modify', 'monotonic', + 'nan', 'nologin', 'norecursive', 'nosuperuser', 'not', 'null', 'of', 'on', 'options', 'or', 'order', 'password', 'permission', 'permissions', 'primary', 'rename', 'replace', 'returns', 'revoke', 'role', 'roles', 'schema', 'select', 'set', 'sfunc', 'smallint', 'static', 'storage', 'stype', 'superuser', 'table', 'text', 'time', 'timestamp', 'timeuuid', - 'tinyint', 'to', 'token', 'trigger', 'truncate', 'ttl', 'tuple', 'type', 'unlogged', 'update', 'use', 'user', - 'users', 'using', 'uuid', 'values', 'varchar', 'varint', 'view', 'where', 'with', 'writetime' + 'tinyint', 'to', 'token', 'trigger', 'truncate', 'ttl', 'tuple', 'type', 'unlogged', 'unset', 'update', 'use', 'user', + 'users', 'using', 'uuid', 'values', 'varchar', 'varint', 'view', 'where', 'with', 'writetime', + + # DSE specifics + "node", "nodes", "plan", "active", "application", "applications", "java", "executor", "executors", "std_out", "std_err", + "renew", "delegation", "no", "redact", "token", "lowercasestring", "cluster", "authentication", "schemes", "scheme", + "internal", "ldap", "kerberos", "remote", "object", "method", "call", "calls", "search", "schema", "config", "rows", + "columns", "profiles", "commit", "reload", "rebuild", "field", "workpool", "any", "submission", "indices", + "restrict", "unrestrict" )) """ Set of keywords in CQL. 
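# Illustrative sketch (not part of the patch): the keyword set above matters for
# identifier quoting -- only the reserved subset (everything not listed in
# cql_keywords_unreserved, defined just below) forces a name to be double-quoted.
# A hypothetical check, assuming both sets are importable from cassandra.metadata:
import re
from cassandra.metadata import cql_keywords, cql_keywords_unreserved

_plain = re.compile(r'^[a-z][a-z0-9_]*$')
reserved = cql_keywords - cql_keywords_unreserved

def needs_quoting(name: str) -> bool:
    # Quote anything that is not a plain lowercase identifier or that collides
    # with a reserved CQL keyword.
    return name in reserved or not _plain.match(name)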
@@ -61,9 +74,9 @@ cql_keywords_unreserved = set(( 'aggregate', 'all', 'as', 'ascii', 'bigint', 'blob', 'boolean', 'called', 'clustering', 'compact', 'contains', - 'count', 'counter', 'custom', 'date', 'decimal', 'distinct', 'double', 'exists', 'filtering', 'finalfunc', 'float', + 'count', 'counter', 'custom', 'date', 'decimal', 'deterministic', 'distinct', 'double', 'exists', 'filtering', 'finalfunc', 'float', 'frozen', 'function', 'functions', 'inet', 'initcond', 'input', 'int', 'json', 'key', 'keys', 'keyspaces', - 'language', 'list', 'login', 'map', 'nologin', 'nosuperuser', 'options', 'password', 'permission', 'permissions', + 'language', 'list', 'login', 'map', 'monotonic', 'nologin', 'nosuperuser', 'options', 'password', 'permission', 'permissions', 'returns', 'role', 'roles', 'sfunc', 'smallint', 'static', 'storage', 'stype', 'superuser', 'text', 'time', 'timestamp', 'timeuuid', 'tinyint', 'trigger', 'ttl', 'tuple', 'type', 'user', 'users', 'uuid', 'values', 'varchar', 'varint', 'writetime' @@ -103,8 +116,12 @@ class Metadata(object): token_map = None """ A :class:`~.TokenMap` instance describing the ring topology. """ + dbaas = False + """ A boolean indicating if connected to a DBaaS cluster """ + def __init__(self): self.keyspaces = {} + self.dbaas = False self._hosts = {} self._hosts_lock = RLock() @@ -117,30 +134,33 @@ def export_schema_as_string(self): def refresh(self, connection, timeout, target_type=None, change_type=None, **kwargs): + server_version = self.get_host(connection.endpoint).release_version + dse_version = self.get_host(connection.endpoint).dse_version + parser = get_schema_parser(connection, server_version, dse_version, timeout) + if not target_type: - self._rebuild_all(connection, timeout) + self._rebuild_all(parser) return tt_lower = target_type.lower() try: - parser = get_schema_parser(connection, timeout) parse_method = getattr(parser, 'get_' + tt_lower) meta = parse_method(self.keyspaces, **kwargs) if meta: update_method = getattr(self, '_update_' + tt_lower) - update_method(meta) + if tt_lower == 'keyspace' and connection.protocol_version < 3: + # we didn't have 'type' target in legacy protocol versions, so we need to query those too + user_types = parser.get_types_map(self.keyspaces, **kwargs) + self._update_keyspace(meta, user_types) + else: + update_method(meta) else: drop_method = getattr(self, '_drop_' + tt_lower) drop_method(**kwargs) except AttributeError: raise ValueError("Unknown schema target_type: '%s'" % target_type) - def _rebuild_all(self, connection, timeout): - """ - For internal use only. 
- """ - parser = get_schema_parser(connection, timeout) - + def _rebuild_all(self, parser): current_keyspaces = set() for keyspace_meta in parser.get_all_keyspaces(): current_keyspaces.add(keyspace_meta.name) @@ -159,13 +179,13 @@ def _rebuild_all(self, connection, timeout): for ksname in removed_keyspaces: self._keyspace_removed(ksname) - def _update_keyspace(self, keyspace_meta): + def _update_keyspace(self, keyspace_meta, new_user_types=None): ks_name = keyspace_meta.name old_keyspace_meta = self.keyspaces.get(ks_name, None) self.keyspaces[ks_name] = keyspace_meta if old_keyspace_meta: keyspace_meta.tables = old_keyspace_meta.tables - keyspace_meta.user_types = old_keyspace_meta.user_types + keyspace_meta.user_types = new_user_types if new_user_types is not None else old_keyspace_meta.user_types keyspace_meta.indexes = old_keyspace_meta.indexes keyspace_meta.functions = old_keyspace_meta.functions keyspace_meta.aggregates = old_keyspace_meta.aggregates @@ -272,9 +292,9 @@ def rebuild_token_map(self, partitioner, token_map): token_to_host_owner = {} ring = [] - for host, token_strings in six.iteritems(token_map): + for host, token_strings in token_map.items(): for token_string in token_strings: - token = token_class(token_string) + token = token_class.from_string(token_string) ring.append(token) token_to_host_owner[token] = host @@ -309,24 +329,40 @@ def add_or_return_host(self, host): """ with self._hosts_lock: try: - return self._hosts[host.address], False + return self._hosts[host.endpoint], False except KeyError: - self._hosts[host.address] = host + self._hosts[host.endpoint] = host return host, True def remove_host(self, host): with self._hosts_lock: - return bool(self._hosts.pop(host.address, False)) + return bool(self._hosts.pop(host.endpoint, False)) + + def get_host(self, endpoint_or_address, port=None): + """ + Find a host in the metadata for a specific endpoint. If a string inet address and port are passed, + iterate all hosts to match the :attr:`~.pool.Host.broadcast_rpc_address` and + :attr:`~.pool.Host.broadcast_rpc_port`attributes. + """ + if not isinstance(endpoint_or_address, EndPoint): + return self._get_host_by_address(endpoint_or_address, port) - def get_host(self, address): - return self._hosts.get(address) + return self._hosts.get(endpoint_or_address) + + def _get_host_by_address(self, address, port=None): + for host in self._hosts.values(): + if (host.broadcast_rpc_address == address and + (port is None or host.broadcast_rpc_port is None or host.broadcast_rpc_port == port)): + return host + + return None def all_hosts(self): """ Returns a list of all known :class:`.Host` instances in the cluster. """ with self._hosts_lock: - return self._hosts.values() + return list(self._hosts.values()) REPLICATION_STRATEGY_CLASS_PREFIX = "org.apache.cassandra.locator." 
@@ -350,8 +386,8 @@ def __new__(metacls, name, bases, dct): return cls -@six.add_metaclass(ReplicationStrategyTypeType) -class _ReplicationStrategy(object): + +class _ReplicationStrategy(object, metaclass=ReplicationStrategyTypeType): options_map = None @classmethod @@ -400,9 +436,9 @@ def __init__(self, name, options_map): self.options_map['class'] = self.name def __eq__(self, other): - return (isinstance(other, _UnknownStrategy) - and self.name == other.name - and self.options_map == other.options_map) + return (isinstance(other, _UnknownStrategy) and + self.name == other.name and + self.options_map == other.options_map) def export_for_schema(self): """ @@ -417,18 +453,82 @@ def make_token_replica_map(self, token_to_host_owner, ring): return {} +class ReplicationFactor(object): + """ + Represent the replication factor of a keyspace. + """ + + all_replicas = None + """ + The number of total replicas. + """ + + full_replicas = None + """ + The number of replicas that own a full copy of the data. This is the same + than `all_replicas` when transient replication is not enabled. + """ + + transient_replicas = None + """ + The number of transient replicas. + + Only set if the keyspace has transient replication enabled. + """ + + def __init__(self, all_replicas, transient_replicas=None): + self.all_replicas = all_replicas + self.transient_replicas = transient_replicas + self.full_replicas = (all_replicas - transient_replicas) if transient_replicas else all_replicas + + @staticmethod + def create(rf): + """ + Given the inputted replication factor string, parse and return the ReplicationFactor instance. + """ + transient_replicas = None + try: + all_replicas = int(rf) + except ValueError: + try: + rf = rf.split('/') + all_replicas, transient_replicas = int(rf[0]), int(rf[1]) + except Exception: + raise ValueError("Unable to determine replication factor from: {}".format(rf)) + + return ReplicationFactor(all_replicas, transient_replicas) + + def __str__(self): + return ("%d/%d" % (self.all_replicas, self.transient_replicas) if self.transient_replicas + else "%d" % self.all_replicas) + + def __eq__(self, other): + if not isinstance(other, ReplicationFactor): + return False + + return self.all_replicas == other.all_replicas and self.full_replicas == other.full_replicas + + class SimpleStrategy(ReplicationStrategy): - replication_factor = None + replication_factor_info = None """ - The replication factor for this keyspace. + A :class:`cassandra.metadata.ReplicationFactor` instance. """ + @property + def replication_factor(self): + """ + The replication factor for this keyspace. + + For backward compatibility, this returns the + :attr:`cassandra.metadata.ReplicationFactor.full_replicas` value of + :attr:`cassandra.metadata.SimpleStrategy.replication_factor_info`. + """ + return self.replication_factor_info.full_replicas + def __init__(self, options_map): - try: - self.replication_factor = int(options_map['replication_factor']) - except Exception: - raise ValueError("SimpleStrategy requires an integer 'replication_factor' option") + self.replication_factor_info = ReplicationFactor.create(options_map['replication_factor']) def make_token_replica_map(self, token_to_host_owner, ring): replica_map = {} @@ -449,47 +549,59 @@ def export_for_schema(self): Returns a string version of these replication options which are suitable for use in a CREATE KEYSPACE statement. 
""" - return "{'class': 'SimpleStrategy', 'replication_factor': '%d'}" \ - % (self.replication_factor,) + return "{'class': 'SimpleStrategy', 'replication_factor': '%s'}" \ + % (str(self.replication_factor_info),) def __eq__(self, other): if not isinstance(other, SimpleStrategy): return False - return self.replication_factor == other.replication_factor + return str(self.replication_factor_info) == str(other.replication_factor_info) class NetworkTopologyStrategy(ReplicationStrategy): + dc_replication_factors_info = None + """ + A map of datacenter names to the :class:`cassandra.metadata.ReplicationFactor` instance for that DC. + """ + dc_replication_factors = None """ A map of datacenter names to the replication factor for that DC. + + For backward compatibility, this maps to the :attr:`cassandra.metadata.ReplicationFactor.full_replicas` + value of the :attr:`cassandra.metadata.NetworkTopologyStrategy.dc_replication_factors_info` dict. """ def __init__(self, dc_replication_factors): + self.dc_replication_factors_info = dict( + (str(k), ReplicationFactor.create(v)) for k, v in dc_replication_factors.items()) self.dc_replication_factors = dict( - (str(k), int(v)) for k, v in dc_replication_factors.items()) + (dc, rf.full_replicas) for dc, rf in self.dc_replication_factors_info.items()) def make_token_replica_map(self, token_to_host_owner, ring): - # note: this does not account for hosts having different racks - replica_map = defaultdict(list) - dc_rf_map = dict((dc, int(rf)) - for dc, rf in self.dc_replication_factors.items() if rf > 0) + dc_rf_map = dict( + (dc, full_replicas) for dc, full_replicas in self.dc_replication_factors.items() + if full_replicas > 0) # build a map of DCs to lists of indexes into `ring` for tokens that # belong to that DC dc_to_token_offset = defaultdict(list) dc_racks = defaultdict(set) + hosts_per_dc = defaultdict(set) for i, token in enumerate(ring): host = token_to_host_owner[token] dc_to_token_offset[host.datacenter].append(i) if host.datacenter and host.rack: dc_racks[host.datacenter].add(host.rack) + hosts_per_dc[host.datacenter].add(host) # A map of DCs to an index into the dc_to_token_offset value for that dc. # This is how we keep track of advancing around the ring for each DC. dc_to_current_index = defaultdict(int) + replica_map = defaultdict(list) for i in range(len(ring)): replicas = replica_map[ring[i]] @@ -508,12 +620,19 @@ def make_token_replica_map(self, token_to_host_owner, ring): dc_to_current_index[dc] = index replicas_remaining = dc_rf_map[dc] + replicas_this_dc = 0 skipped_hosts = [] racks_placed = set() racks_this_dc = dc_racks[dc] - for token_offset in islice(cycle(token_offsets), index, index + num_tokens): + hosts_this_dc = len(hosts_per_dc[dc]) + + for token_offset_index in range(index, index+num_tokens): + if token_offset_index >= len(token_offsets): + token_offset_index = token_offset_index - len(token_offsets) + + token_offset = token_offsets[token_offset_index] host = token_to_host_owner[ring[token_offset]] - if replicas_remaining == 0: + if replicas_remaining == 0 or replicas_this_dc == hosts_this_dc: break if host in replicas: @@ -524,6 +643,7 @@ def make_token_replica_map(self, token_to_host_owner, ring): continue replicas.append(host) + replicas_this_dc += 1 replicas_remaining -= 1 racks_placed.add(host.rack) @@ -543,15 +663,15 @@ def export_for_schema(self): suitable for use in a CREATE KEYSPACE statement. 
""" ret = "{'class': 'NetworkTopologyStrategy'" - for dc, repl_factor in sorted(self.dc_replication_factors.items()): - ret += ", '%s': '%d'" % (dc, repl_factor) + for dc, rf in sorted(self.dc_replication_factors_info.items()): + ret += ", '%s': '%s'" % (dc, str(rf)) return ret + "}" def __eq__(self, other): if not isinstance(other, NetworkTopologyStrategy): return False - return self.dc_replication_factors == other.dc_replication_factors + return self.dc_replication_factors_info == other.dc_replication_factors_info class LocalStrategy(ReplicationStrategy): @@ -627,10 +747,23 @@ class KeyspaceMetadata(object): A dict mapping view names to :class:`.MaterializedViewMetadata` instances. """ + virtual = False + """ + A boolean indicating if this is a virtual keyspace or not. Always ``False`` + for clusters running Cassandra pre-4.0 and DSE pre-6.7 versions. + + .. versionadded:: 3.15 + """ + + graph_engine = None + """ + A string indicating whether a graph engine is enabled for this keyspace (Core/Classic). + """ + _exc_info = None """ set if metadata parsing failed """ - def __init__(self, name, durable_writes, strategy_class, strategy_options): + def __init__(self, name, durable_writes, strategy_class, strategy_options, graph_engine=None): self.name = name self.durable_writes = durable_writes self.replication_strategy = ReplicationStrategy.create(strategy_class, strategy_options) @@ -640,17 +773,28 @@ def __init__(self, name, durable_writes, strategy_class, strategy_options): self.functions = {} self.aggregates = {} self.views = {} + self.graph_engine = graph_engine + + @property + def is_graph_enabled(self): + return self.graph_engine is not None def export_as_string(self): """ Returns a CQL query string that can be used to recreate the entire keyspace, including user-defined types and tables. """ - cql = "\n\n".join([self.as_cql_query() + ';'] - + self.user_type_strings() - + [f.export_as_string() for f in self.functions.values()] - + [a.export_as_string() for a in self.aggregates.values()] - + [t.export_as_string() for t in self.tables.values()]) + # Make sure tables with vertex are exported before tables with edges + tables_with_vertex = [t for t in self.tables.values() if hasattr(t, 'vertex') and t.vertex] + other_tables = [t for t in self.tables.values() if t not in tables_with_vertex] + + cql = "\n\n".join( + [self.as_cql_query() + ';'] + + self.user_type_strings() + + [f.export_as_string() for f in self.functions.values()] + + [a.export_as_string() for a in self.aggregates.values()] + + [t.export_as_string() for t in tables_with_vertex + other_tables]) + if self._exc_info: import traceback ret = "/*\nWarning: Keyspace %s is incomplete because of an error processing metadata.\n" % \ @@ -659,6 +803,11 @@ def export_as_string(self): ret += line ret += "\nApproximate structure, for reference:\n(this should not be used to reproduce this schema)\n\n%s\n*/" % cql return ret + if self.virtual: + return ("/*\nWarning: Keyspace {ks} is a virtual keyspace and cannot be recreated with CQL.\n" + "Structure, for reference:*/\n" + "{cql}\n" + "").format(ks=self.name, cql=cql) return cql def as_cql_query(self): @@ -666,10 +815,15 @@ def as_cql_query(self): Returns a CQL query string that can be used to recreate just this keyspace, not including user-defined types and tables. 
""" + if self.virtual: + return "// VIRTUAL KEYSPACE {}".format(protect_name(self.name)) ret = "CREATE KEYSPACE %s WITH replication = %s " % ( protect_name(self.name), self.replication_strategy.export_for_schema()) - return ret + (' AND durable_writes = %s' % ("true" if self.durable_writes else "false")) + ret = ret + (' AND durable_writes = %s' % ("true" if self.durable_writes else "false")) + if self.graph_engine is not None: + ret = ret + (" AND graph_engine = '%s'" % self.graph_engine) + return ret def user_type_strings(self): user_type_strings = [] @@ -699,7 +853,7 @@ def _add_table_metadata(self, table_metadata): # note the intentional order of add before remove # this makes sure the maps are never absent something that existed before this update - for index_name, index_metadata in six.iteritems(table_metadata.indexes): + for index_name, index_metadata in table_metadata.indexes.items(): self.indexes[index_name] = index_metadata for index_name in (n for n in old_indexes if n not in table_metadata.indexes): @@ -847,8 +1001,15 @@ class Aggregate(object): Type of the aggregate state """ + deterministic = None + """ + Flag indicating if this function is guaranteed to produce the same result + for a particular input and state. This is available only with DSE >=6.0. + """ + def __init__(self, keyspace, name, argument_types, state_func, - state_type, final_func, initial_condition, return_type): + state_type, final_func, initial_condition, return_type, + deterministic): self.keyspace = keyspace self.name = name self.argument_types = argument_types @@ -857,6 +1018,7 @@ def __init__(self, keyspace, name, argument_types, state_func, self.final_func = final_func self.initial_condition = initial_condition self.return_type = return_type + self.deterministic = deterministic def as_cql_query(self, formatted=False): """ @@ -867,9 +1029,9 @@ def as_cql_query(self, formatted=False): sep = '\n ' if formatted else ' ' keyspace = protect_name(self.keyspace) name = protect_name(self.name) - type_list = ', '.join(self.argument_types) + type_list = ', '.join([types.strip_frozen(arg_type) for arg_type in self.argument_types]) state_func = protect_name(self.state_func) - state_type = self.state_type + state_type = types.strip_frozen(self.state_type) ret = "CREATE AGGREGATE %(keyspace)s.%(name)s(%(type_list)s)%(sep)s" \ "SFUNC %(state_func)s%(sep)s" \ @@ -877,6 +1039,7 @@ def as_cql_query(self, formatted=False): ret += ''.join((sep, 'FINALFUNC ', protect_name(self.final_func))) if self.final_func else '' ret += ''.join((sep, 'INITCOND ', self.initial_condition)) if self.initial_condition is not None else '' + ret += '{}DETERMINISTIC'.format(sep) if self.deterministic else '' return ret @@ -938,8 +1101,27 @@ class Function(object): (convenience function to avoid handling nulls explicitly if the result will just be null) """ + deterministic = None + """ + Flag indicating if this function is guaranteed to produce the same result + for a particular input. This is available only for DSE >=6.0. + """ + + monotonic = None + """ + Flag indicating if this function is guaranteed to increase or decrease + monotonically on any of its arguments. This is available only for DSE >=6.0. + """ + + monotonic_on = None + """ + A list containing the argument or arguments over which this function is + monotonic. This is available only for DSE >=6.0. 
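# Illustrative sketch (not part of the patch): how the new DSE-only flags show
# up in Function.as_cql_query() (defined just below). The function definition
# itself is hypothetical.
from cassandra.metadata import Function

fn = Function('ks', 'plus', ['int', 'int'], ['a', 'b'], 'int', 'java',
              'return a + b;', called_on_null_input=True,
              deterministic=True, monotonic=False, monotonic_on=['a'])
print(fn.as_cql_query())
# roughly: CREATE FUNCTION ks.plus(a int, b int) CALLED ON NULL INPUT RETURNS int
#          DETERMINISTIC MONOTONIC ON a LANGUAGE java AS $$return a + b;$$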
+ """ + def __init__(self, keyspace, name, argument_types, argument_names, - return_type, language, body, called_on_null_input): + return_type, language, body, called_on_null_input, + deterministic, monotonic, monotonic_on): self.keyspace = keyspace self.name = name self.argument_types = argument_types @@ -950,6 +1132,9 @@ def __init__(self, keyspace, name, argument_types, argument_names, self.language = language self.body = body self.called_on_null_input = called_on_null_input + self.deterministic = deterministic + self.monotonic = monotonic + self.monotonic_on = monotonic_on def as_cql_query(self, formatted=False): """ @@ -960,16 +1145,31 @@ def as_cql_query(self, formatted=False): sep = '\n ' if formatted else ' ' keyspace = protect_name(self.keyspace) name = protect_name(self.name) - arg_list = ', '.join(["%s %s" % (protect_name(n), t) + arg_list = ', '.join(["%s %s" % (protect_name(n), types.strip_frozen(t)) for n, t in zip(self.argument_names, self.argument_types)]) typ = self.return_type lang = self.language body = self.body on_null = "CALLED" if self.called_on_null_input else "RETURNS NULL" + deterministic_token = ('DETERMINISTIC{}'.format(sep) + if self.deterministic else + '') + monotonic_tokens = '' # default for nonmonotonic function + if self.monotonic: + # monotonic on all arguments; ignore self.monotonic_on + monotonic_tokens = 'MONOTONIC{}'.format(sep) + elif self.monotonic_on: + # if monotonic == False and monotonic_on is nonempty, we know that + # monotonicity was specified with MONOTONIC ON , so there's + # exactly 1 value there + monotonic_tokens = 'MONOTONIC ON {}{}'.format(self.monotonic_on[0], + sep) return "CREATE FUNCTION %(keyspace)s.%(name)s(%(arg_list)s)%(sep)s" \ "%(on_null)s ON NULL INPUT%(sep)s" \ "RETURNS %(typ)s%(sep)s" \ + "%(deterministic_token)s" \ + "%(monotonic_tokens)s" \ "LANGUAGE %(lang)s%(sep)s" \ "AS $$%(body)s$$" % locals() @@ -1053,26 +1253,38 @@ def primary_key(self): _exc_info = None """ set if metadata parsing failed """ + virtual = False + """ + A boolean indicating if this is a virtual table or not. Always ``False`` + for clusters running Cassandra pre-4.0 and DSE pre-6.7 versions. + + .. 
versionadded:: 3.15 + """ + @property def is_cql_compatible(self): """ A boolean indicating if this table can be represented as CQL in export """ + if self.virtual: + return False comparator = getattr(self, 'comparator', None) if comparator: - # no such thing as DCT in CQL - incompatible = issubclass(self.comparator, types.DynamicCompositeType) - # no compact storage with more than one column beyond PK if there # are clustering columns - incompatible |= (self.is_compact_storage and - len(self.columns) > len(self.primary_key) + 1 and - len(self.clustering_key) >= 1) + incompatible = (self.is_compact_storage and + len(self.columns) > len(self.primary_key) + 1 and + len(self.clustering_key) >= 1) return not incompatible return True - def __init__(self, keyspace_name, name, partition_key=None, clustering_key=None, columns=None, triggers=None, options=None): + extensions = None + """ + Metadata describing configuration for table extensions + """ + + def __init__(self, keyspace_name, name, partition_key=None, clustering_key=None, columns=None, triggers=None, options=None, virtual=False): self.keyspace_name = keyspace_name self.name = name self.partition_key = [] if partition_key is None else partition_key @@ -1083,6 +1295,7 @@ def __init__(self, keyspace_name, name, partition_key=None, clustering_key=None, self.comparator = None self.triggers = OrderedDict() if triggers is None else triggers self.views = {} + self.virtual = virtual def export_as_string(self): """ @@ -1102,6 +1315,11 @@ def export_as_string(self): ret = "/*\nWarning: Table %s.%s omitted because it has constructs not compatible with CQL (was created via legacy API).\n" % \ (self.keyspace_name, self.name) ret += "\nApproximate structure, for reference:\n(this should not be used to reproduce this schema)\n\n%s\n*/" % self._all_as_cql() + elif self.virtual: + ret = ('/*\nWarning: Table {ks}.{tab} is a virtual table and cannot be recreated with CQL.\n' + 'Structure, for reference:\n' + '{cql}\n*/').format(ks=self.keyspace_name, tab=self.name, cql=self._all_as_cql()) + else: ret = self._all_as_cql() @@ -1120,6 +1338,14 @@ def _all_as_cql(self): for view_meta in self.views.values(): ret += "\n\n%s;" % (view_meta.as_cql_query(formatted=True),) + if self.extensions: + registry = _RegisteredExtensionType._extension_registry + for k in registry.keys() & self.extensions: # no viewkeys on OrderedMapSerializeKey + ext = registry[k] + cql = ext.after_table_cql(self, k, self.extensions[k]) + if cql: + ret += "\n\n%s" % (cql,) + return ret def as_cql_query(self, formatted=False): @@ -1128,7 +1354,8 @@ def as_cql_query(self, formatted=False): creations are not included). If `formatted` is set to :const:`True`, extra whitespace will be added to make the query human readable. """ - ret = "CREATE TABLE %s.%s (%s" % ( + ret = "%s TABLE %s.%s (%s" % ( + ('VIRTUAL' if self.virtual else 'CREATE'), protect_name(self.keyspace_name), protect_name(self.name), "\n" if formatted else "") @@ -1221,14 +1448,126 @@ def _make_option_strings(cls, options_map): return list(sorted(ret)) -if six.PY3: - def protect_name(name): - return maybe_escape_name(name) -else: - def protect_name(name): # NOQA - if isinstance(name, six.text_type): - name = name.encode('utf8') - return maybe_escape_name(name) +class TableMetadataV3(TableMetadata): + """ + For C* 3.0+. `option_maps` take a superset of map names, so if nothing + changes structurally, new option maps can just be appended to the list. 
+ """ + compaction_options = {} + + option_maps = [ + 'compaction', 'compression', 'caching', + 'nodesync' # added DSE 6.0 + ] + + @property + def is_cql_compatible(self): + return True + + @classmethod + def _make_option_strings(cls, options_map): + ret = [] + options_copy = dict(options_map.items()) + + for option in cls.option_maps: + value = options_copy.get(option) + if isinstance(value, Mapping): + del options_copy[option] + params = ("'%s': '%s'" % (k, v) for k, v in value.items()) + ret.append("%s = {%s}" % (option, ', '.join(params))) + + for name, value in options_copy.items(): + if value is not None: + if name == "comment": + value = value or "" + ret.append("%s = %s" % (name, protect_value(value))) + + return list(sorted(ret)) + + +class TableMetadataDSE68(TableMetadataV3): + + vertex = None + """A :class:`.VertexMetadata` instance, if graph enabled""" + + edge = None + """A :class:`.EdgeMetadata` instance, if graph enabled""" + + def as_cql_query(self, formatted=False): + ret = super(TableMetadataDSE68, self).as_cql_query(formatted) + + if self.vertex: + ret += " AND VERTEX LABEL %s" % protect_name(self.vertex.label_name) + + if self.edge: + ret += " AND EDGE LABEL %s" % protect_name(self.edge.label_name) + + ret += self._export_edge_as_cql( + self.edge.from_label, + self.edge.from_partition_key_columns, + self.edge.from_clustering_columns, "FROM") + + ret += self._export_edge_as_cql( + self.edge.to_label, + self.edge.to_partition_key_columns, + self.edge.to_clustering_columns, "TO") + + return ret + + @staticmethod + def _export_edge_as_cql(label_name, partition_keys, + clustering_columns, keyword): + ret = " %s %s(" % (keyword, protect_name(label_name)) + + if len(partition_keys) == 1: + ret += protect_name(partition_keys[0]) + else: + ret += "(%s)" % ", ".join([protect_name(k) for k in partition_keys]) + + if clustering_columns: + ret += ", %s" % ", ".join([protect_name(k) for k in clustering_columns]) + ret += ")" + + return ret + + +class TableExtensionInterface(object): + """ + Defines CQL/DDL for Cassandra table extensions. + """ + # limited API for now. Could be expanded as new extension types materialize -- "extend_option_strings", for example + @classmethod + def after_table_cql(cls, ext_key, ext_blob): + """ + Called to produce CQL/DDL to follow the table definition. + Should contain requisite terminating semicolon(s). + """ + pass + + +class _RegisteredExtensionType(type): + + _extension_registry = {} + + def __new__(mcs, name, bases, dct): + cls = super(_RegisteredExtensionType, mcs).__new__(mcs, name, bases, dct) + if name != 'RegisteredTableExtension': + mcs._extension_registry[cls.name] = cls + return cls + + +class RegisteredTableExtension(TableExtensionInterface, metaclass=_RegisteredExtensionType): + """ + Extending this class registers it by name (associated by key in the `system_schema.tables.extensions` map). 
+ """ + name = None + """ + Name of the extension (key in the map) + """ + + +def protect_name(name): + return maybe_escape_name(name) def protect_names(names): @@ -1338,20 +1677,22 @@ def as_cql_query(self): index_target = options.pop("target") if self.kind != "CUSTOM": return "CREATE INDEX %s ON %s.%s (%s)" % ( - self.name, # Cassandra doesn't like quoted index names for some reason + protect_name(self.name), protect_name(self.keyspace_name), protect_name(self.table_name), index_target) else: class_name = options.pop("class_name") ret = "CREATE CUSTOM INDEX %s ON %s.%s (%s) USING '%s'" % ( - self.name, # Cassandra doesn't like quoted index names for some reason + protect_name(self.name), protect_name(self.keyspace_name), protect_name(self.table_name), index_target, class_name) if options: - ret += " WITH OPTIONS = %s" % Encoder().cql_encode_all_types(options) + # PYTHON-1008: `ret` will always be a unicode + opts_cql_encoded = _encoder.cql_encode_all_types(options, as_text_type=True) + ret += " WITH OPTIONS = %s" % opts_cql_encoded return ret def export_as_string(self): @@ -1400,10 +1741,18 @@ def __init__(self, token_class, token_to_host_owner, all_tokens, metadata): def rebuild_keyspace(self, keyspace, build_if_absent=False): with self._rebuild_lock: - current = self.tokens_to_hosts_by_ks.get(keyspace, None) - if (build_if_absent and current is None) or (not build_if_absent and current is not None): - replica_map = self.replica_map_for_keyspace(self._metadata.keyspaces[keyspace]) - self.tokens_to_hosts_by_ks[keyspace] = replica_map + try: + current = self.tokens_to_hosts_by_ks.get(keyspace, None) + if (build_if_absent and current is None) or (not build_if_absent and current is not None): + ks_meta = self._metadata.keyspaces.get(keyspace) + if ks_meta: + replica_map = self.replica_map_for_keyspace(self._metadata.keyspaces[keyspace]) + self.tokens_to_hosts_by_ks[keyspace] = replica_map + except Exception: + # should not happen normally, but we don't want to blow up queries because of unexpected meta state + # bypass until new map is generated + self.tokens_to_hosts_by_ks[keyspace] = {} + log.exception("Failed creating a token map for keyspace '%s' with %s. PLEASE REPORT THIS: https://datastax-oss.atlassian.net/projects/PYTHON", keyspace, self.token_to_host_owner) def replica_map_for_keyspace(self, ks_metadata): strategy = ks_metadata.replication_strategy @@ -1426,10 +1775,9 @@ def get_replicas(self, keyspace, token): tokens_to_hosts = self.tokens_to_hosts_by_ks.get(keyspace, None) if tokens_to_hosts: - # token range ownership is exclusive on the LHS (the start token), so - # we use bisect_right, which, in the case of a tie/exact match, - # picks an insertion point to the right of the existing match - point = bisect_right(self.ring, token) + # The values in self.ring correspond to the end of the + # token range up to and including the value listed. + point = bisect_left(self.ring, token) if point == len(self.ring): return tokens_to_hosts[self.ring[0]] else: @@ -1437,11 +1785,15 @@ def get_replicas(self, keyspace, token): return [] +@total_ordering class Token(object): """ Abstract class representing a token. 
""" + def __init__(self, token): + self.value = token + @classmethod def hash_fn(cls, key): return key @@ -1450,13 +1802,9 @@ def hash_fn(cls, key): def from_key(cls, key): return cls(cls.hash_fn(key)) - def __cmp__(self, other): - if self.value < other.value: - return -1 - elif self.value == other.value: - return 0 - else: - return 1 + @classmethod + def from_string(cls, token_string): + raise NotImplementedError() def __eq__(self, other): return self.value == other.value @@ -1471,6 +1819,7 @@ def __repr__(self): return "<%s: %s>" % (self.__class__.__name__, self.value) __str__ = __repr__ + MIN_LONG = -(2 ** 63) MAX_LONG = (2 ** 63) - 1 @@ -1479,7 +1828,16 @@ class NoMurmur3(Exception): pass -class Murmur3Token(Token): +class HashToken(Token): + + @classmethod + def from_string(cls, token_string): + """ `token_string` should be the string representation from the server. """ + # The hash partitioners just store the deciman value + return cls(int(token_string)) + + +class Murmur3Token(HashToken): """ A token for ``Murmur3Partitioner``. """ @@ -1493,38 +1851,35 @@ def hash_fn(cls, key): raise NoMurmur3() def __init__(self, token): - """ `token` should be an int or string representing the token. """ + """ `token` is an int or string representing the token. """ self.value = int(token) -class MD5Token(Token): +class MD5Token(HashToken): """ A token for ``RandomPartitioner``. """ @classmethod def hash_fn(cls, key): - if isinstance(key, six.text_type): + if isinstance(key, str): key = key.encode('UTF-8') return abs(varint_unpack(md5(key).digest())) - def __init__(self, token): - """ `token` should be an int or string representing the token. """ - self.value = int(token) - class BytesToken(Token): """ A token for ``ByteOrderedPartitioner``. """ - def __init__(self, token_string): - """ `token_string` should be string representing the token. """ - if not isinstance(token_string, six.string_types): - raise TypeError( - "Tokens for ByteOrderedPartitioner should be strings (got %s)" - % (type(token_string),)) - self.value = token_string + @classmethod + def from_string(cls, token_string): + """ `token_string` should be the string representation from the server. """ + # unhexlify works fine with unicode input in everythin but pypy3, where it Raises "TypeError: 'str' does not support the buffer interface" + if isinstance(token_string, str): + token_string = token_string.encode('ascii') + # The BOP stores a hex string + return cls(unhexlify(token_string)) class TriggerMetadata(object): @@ -1567,21 +1922,56 @@ def __init__(self, connection, timeout): self.connection = connection self.timeout = timeout - def _handle_results(self, success, result): - if success: - return dict_factory(*result.results) if result else [] + def _handle_results(self, success, result, expected_failures=tuple()): + """ + Given a bool and a ResultSet (the form returned per result from + Connection.wait_for_responses), return a dictionary containing the + results. Used to process results from asynchronous queries to system + tables. + + ``expected_failures`` will usually be used to allow callers to ignore + ``InvalidRequest`` errors caused by a missing system keyspace. For + example, some DSE versions report a 4.X server version, but do not have + virtual tables. Thus, running against 4.X servers, SchemaParserV4 uses + expected_failures to make a best-effort attempt to read those + keyspaces, but treat them as empty if they're not found. 
+ + :param success: A boolean representing whether or not the query + succeeded + :param result: The resultset in question. + :expected_failures: An Exception class or an iterable thereof. If the + query failed, but raised an instance of an expected failure class, this + will ignore the failure and return an empty list. + """ + if not success and isinstance(result, expected_failures): + return [] + elif success: + return dict_factory(result.column_names, result.parsed_rows) if result else [] else: raise result def _query_build_row(self, query_string, build_func): + result = self._query_build_rows(query_string, build_func) + return result[0] if result else None + + def _query_build_rows(self, query_string, build_func): query = QueryMessage(query=query_string, consistency_level=ConsistencyLevel.ONE) - response = self.connection.wait_for_response(query, self.timeout) - result = dict_factory(*response.results) - if result: - return build_func(result[0]) + responses = self.connection.wait_for_responses((query), timeout=self.timeout, fail_on_error=False) + (success, response) = responses[0] + if success: + result = dict_factory(response.column_names, response.parsed_rows) + return [build_func(row) for row in result] + elif isinstance(response, InvalidRequest): + log.debug("user types table not found") + return [] + else: + raise response class SchemaParserV22(_SchemaParser): + """ + For C* 2.2+ + """ _SELECT_KEYSPACES = "SELECT * FROM system.schema_keyspaces" _SELECT_COLUMN_FAMILIES = "SELECT * FROM system.schema_columnfamilies" _SELECT_COLUMNS = "SELECT * FROM system.schema_columns" @@ -1674,11 +2064,9 @@ def get_table(self, keyspaces, keyspace, table): table_result = self._handle_results(cf_success, cf_result) col_result = self._handle_results(col_success, col_result) - # handle the triggers table not existing in Cassandra 1.2 - if not triggers_success and isinstance(triggers_result, InvalidRequest): - triggers_result = [] - else: - triggers_result = self._handle_results(triggers_success, triggers_result) + # the triggers table doesn't exist in C* 1.2 + triggers_result = self._handle_results(triggers_success, triggers_result, + expected_failures=InvalidRequest) if table_result: return self._build_table_metadata(table_result[0], col_result, triggers_result) @@ -1687,6 +2075,11 @@ def get_type(self, keyspaces, keyspace, type): where_clause = bind_params(" WHERE keyspace_name = %s AND type_name = %s", (keyspace, type), _encoder) return self._query_build_row(self._SELECT_TYPES + where_clause, self._build_user_type) + def get_types_map(self, keyspaces, keyspace): + where_clause = bind_params(" WHERE keyspace_name = %s", (keyspace,), _encoder) + types = self._query_build_rows(self._SELECT_TYPES + where_clause, self._build_user_type) + return dict((t.name, t) for t in types) + def get_function(self, keyspaces, keyspace, function): where_clause = bind_params(" WHERE keyspace_name = %%s AND function_name = %%s AND %s = %%s" % (self._function_agg_arument_type_col,), (keyspace, function.name, function.argument_types), _encoder) @@ -1730,10 +2123,14 @@ def _build_user_type(cls, usertype_row): @classmethod def _build_function(cls, function_row): return_type = cls._schema_type_to_cql(function_row['return_type']) + deterministic = function_row.get('deterministic', False) + monotonic = function_row.get('monotonic', False) + monotonic_on = function_row.get('monotonic_on', ()) return Function(function_row['keyspace_name'], function_row['function_name'], function_row[cls._function_agg_arument_type_col], 
function_row['argument_names'], return_type, function_row['language'], function_row['body'], - function_row['called_on_null_input']) + function_row['called_on_null_input'], + deterministic, monotonic, monotonic_on) @classmethod def _build_aggregate(cls, aggregate_row): @@ -1745,7 +2142,8 @@ def _build_aggregate(cls, aggregate_row): return_type = cls._schema_type_to_cql(aggregate_row['return_type']) return Aggregate(aggregate_row['keyspace_name'], aggregate_row['aggregate_name'], aggregate_row['signature'], aggregate_row['state_func'], state_type, - aggregate_row['final_func'], initial_condition, return_type) + aggregate_row['final_func'], initial_condition, return_type, + aggregate_row.get('deterministic', False)) def _build_table_metadata(self, row, col_rows=None, trigger_rows=None): keyspace_name = row["keyspace_name"] @@ -1764,12 +2162,9 @@ def _build_table_metadata(self, row, col_rows=None, trigger_rows=None): comparator = types.lookup_casstype(row["comparator"]) table_meta.comparator = comparator - if issubclass(comparator, types.CompositeType): - column_name_types = comparator.subtypes - is_composite_comparator = True - else: - column_name_types = (comparator,) - is_composite_comparator = False + is_dct_comparator = issubclass(comparator, types.DynamicCompositeType) + is_composite_comparator = issubclass(comparator, types.CompositeType) + column_name_types = comparator.subtypes if is_composite_comparator else (comparator,) num_column_name_components = len(column_name_types) last_col = column_name_types[-1] @@ -1783,7 +2178,8 @@ def _build_table_metadata(self, row, col_rows=None, trigger_rows=None): if column_aliases is not None: column_aliases = json.loads(column_aliases) - else: + + if not column_aliases: # json load failed or column_aliases empty PYTHON-562 column_aliases = [r.get('column_name') for r in clustering_rows] if is_composite_comparator: @@ -1792,8 +2188,8 @@ def _build_table_metadata(self, row, col_rows=None, trigger_rows=None): is_compact = False has_value = False clustering_size = num_column_name_components - 2 - elif (len(column_aliases) == num_column_name_components - 1 - and issubclass(last_col, types.UTF8Type)): + elif (len(column_aliases) == num_column_name_components - 1 and + issubclass(last_col, types.UTF8Type)): # aliases? 
is_compact = False has_value = False @@ -1806,10 +2202,10 @@ def _build_table_metadata(self, row, col_rows=None, trigger_rows=None): # Some thrift tables define names in composite types (see PYTHON-192) if not column_aliases and hasattr(comparator, 'fieldnames'): - column_aliases = comparator.fieldnames + column_aliases = filter(None, comparator.fieldnames) else: is_compact = True - if column_aliases or not col_rows: + if column_aliases or not col_rows or is_dct_comparator: has_value = True clustering_size = num_column_name_components else: @@ -1854,7 +2250,7 @@ def _build_table_metadata(self, row, col_rows=None, trigger_rows=None): if len(column_aliases) > i: column_name = column_aliases[i] else: - column_name = "column%d" % i + column_name = "column%d" % (i + 1) data_type = column_name_types[i] cql_type = _cql_from_cass_type(data_type) @@ -1891,7 +2287,7 @@ def _build_table_metadata(self, row, col_rows=None, trigger_rows=None): # other normal columns for col_row in col_rows: column_meta = self._build_column_metadata(table_meta, col_row) - if column_meta.name: + if column_meta.name is not None: table_meta.columns[column_meta.name] = column_meta index_meta = self._build_index_metadata(column_meta, col_row) if index_meta: @@ -1985,12 +2381,16 @@ def _query_all(self): QueryMessage(query=self._SELECT_TRIGGERS, consistency_level=cl) ] - responses = self.connection.wait_for_responses(*queries, timeout=self.timeout, fail_on_error=False) - (ks_success, ks_result), (table_success, table_result), \ - (col_success, col_result), (types_success, types_result), \ - (functions_success, functions_result), \ - (aggregates_success, aggregates_result), \ - (triggers_success, triggers_result) = responses + ((ks_success, ks_result), + (table_success, table_result), + (col_success, col_result), + (types_success, types_result), + (functions_success, functions_result), + (aggregates_success, aggregates_result), + (triggers_success, triggers_result)) = ( + self.connection.wait_for_responses(*queries, timeout=self.timeout, + fail_on_error=False) + ) self.keyspaces_result = self._handle_results(ks_success, ks_result) self.tables_result = self._handle_results(table_success, table_result) @@ -1998,7 +2398,7 @@ def _query_all(self): # if we're connected to Cassandra < 2.0, the triggers table will not exist if triggers_success: - self.triggers_result = dict_factory(*triggers_result.results) + self.triggers_result = dict_factory(triggers_result.column_names, triggers_result.parsed_rows) else: if isinstance(triggers_result, InvalidRequest): log.debug("triggers table not found") @@ -2010,7 +2410,7 @@ def _query_all(self): # if we're connected to Cassandra < 2.1, the usertypes table will not exist if types_success: - self.types_result = dict_factory(*types_result.results) + self.types_result = dict_factory(types_result.column_names, types_result.parsed_rows) else: if isinstance(types_result, InvalidRequest): log.debug("user types table not found") @@ -2020,7 +2420,7 @@ def _query_all(self): # functions were introduced in Cassandra 2.2 if functions_success: - self.functions_result = dict_factory(*functions_result.results) + self.functions_result = dict_factory(functions_result.column_names, functions_result.parsed_rows) else: if isinstance(functions_result, InvalidRequest): log.debug("user functions table not found") @@ -2029,7 +2429,7 @@ def _query_all(self): # aggregates were introduced in Cassandra 2.2 if aggregates_success: - self.aggregates_result = dict_factory(*aggregates_result.results) + self.aggregates_result = 
dict_factory(aggregates_result.column_names, aggregates_result.parsed_rows) else: if isinstance(aggregates_result, InvalidRequest): log.debug("user aggregates table not found") @@ -2074,6 +2474,9 @@ def _schema_type_to_cql(type_string): class SchemaParserV3(SchemaParserV22): + """ + For C* 3.0+ + """ _SELECT_KEYSPACES = "SELECT * FROM system_schema.keyspaces" _SELECT_TABLES = "SELECT * FROM system_schema.tables" _SELECT_COLUMNS = "SELECT * FROM system_schema.columns" @@ -2088,9 +2491,12 @@ class SchemaParserV3(SchemaParserV22): _function_agg_arument_type_col = 'argument_types' + _table_metadata_class = TableMetadataV3 + recognized_table_options = ( 'bloom_filter_fp_chance', 'caching', + 'cdc', 'comment', 'compaction', 'compression', @@ -2129,10 +2535,13 @@ def get_table(self, keyspaces, keyspace, table): where_clause = bind_params(" WHERE keyspace_name = %s AND view_name = %s", (keyspace, table), _encoder) view_query = QueryMessage(query=self._SELECT_VIEWS + where_clause, consistency_level=cl) - (cf_success, cf_result), (col_success, col_result), (indexes_sucess, indexes_result), \ - (triggers_success, triggers_result), (view_success, view_result) \ - = self.connection.wait_for_responses(cf_query, col_query, indexes_query, triggers_query, view_query, - timeout=self.timeout, fail_on_error=False) + ((cf_success, cf_result), (col_success, col_result), + (indexes_sucess, indexes_result), (triggers_success, triggers_result), + (view_success, view_result)) = ( + self.connection.wait_for_responses( + cf_query, col_query, indexes_query, triggers_query, + view_query, timeout=self.timeout, fail_on_error=False) + ) table_result = self._handle_results(cf_success, cf_result) col_result = self._handle_results(col_success, col_result) if table_result: @@ -2156,9 +2565,10 @@ def _build_keyspace_metadata_internal(row): def _build_aggregate(aggregate_row): return Aggregate(aggregate_row['keyspace_name'], aggregate_row['aggregate_name'], aggregate_row['argument_types'], aggregate_row['state_func'], aggregate_row['state_type'], - aggregate_row['final_func'], aggregate_row['initcond'], aggregate_row['return_type']) + aggregate_row['final_func'], aggregate_row['initcond'], aggregate_row['return_type'], + aggregate_row.get('deterministic', False)) - def _build_table_metadata(self, row, col_rows=None, trigger_rows=None, index_rows=None): + def _build_table_metadata(self, row, col_rows=None, trigger_rows=None, index_rows=None, virtual=False): keyspace_name = row["keyspace_name"] table_name = row[self._table_name_col] @@ -2166,20 +2576,24 @@ def _build_table_metadata(self, row, col_rows=None, trigger_rows=None, index_row trigger_rows = trigger_rows or self.keyspace_table_trigger_rows[keyspace_name][table_name] index_rows = index_rows or self.keyspace_table_index_rows[keyspace_name][table_name] - table_meta = TableMetadataV3(keyspace_name, table_name) + table_meta = self._table_metadata_class(keyspace_name, table_name, virtual=virtual) try: table_meta.options = self._build_table_options(row) flags = row.get('flags', set()) if flags: - compact_static = False - table_meta.is_compact_storage = 'dense' in flags or 'super' in flags or 'compound' not in flags is_dense = 'dense' in flags + compact_static = not is_dense and 'super' not in flags and 'compound' not in flags + table_meta.is_compact_storage = is_dense or 'super' in flags or 'compound' not in flags + elif virtual: + compact_static = False + table_meta.is_compact_storage = False + is_dense = False else: compact_static = True table_meta.is_compact_storage = True 
is_dense = False - self._build_table_columns(table_meta, col_rows, compact_static, is_dense) + self._build_table_columns(table_meta, col_rows, compact_static, is_dense, virtual) for trigger_row in trigger_rows: trigger_meta = self._build_trigger_metadata(table_meta, trigger_row) @@ -2189,6 +2603,8 @@ def _build_table_metadata(self, row, col_rows=None, trigger_rows=None, index_row index_meta = self._build_index_metadata(table_meta, index_row) if index_meta: table_meta.indexes[index_meta.name] = index_meta + + table_meta.extensions = row.get('extensions', {}) except Exception: table_meta._exc_info = sys.exc_info() log.exception("Error while parsing metadata for table %s.%s row(%s) columns(%s)", keyspace_name, table_name, row, col_rows) @@ -2199,7 +2615,7 @@ def _build_table_options(self, row): """ Setup the mostly-non-schema table options, like caching settings """ return dict((o, row.get(o)) for o in self.recognized_table_options if o in row) - def _build_table_columns(self, meta, col_rows, compact_static=False, is_dense=False): + def _build_table_columns(self, meta, col_rows, compact_static=False, is_dense=False, virtual=False): # partition key partition_rows = [r for r in col_rows if r.get('kind', None) == "partition_key"] @@ -2246,6 +2662,7 @@ def _build_view_metadata(self, row, col_rows=None): view_meta = MaterializedViewMetadata(keyspace_name, view_name, base_table_name, include_all_columns, where_clause, self._build_table_options(row)) self._build_table_columns(view_meta, col_rows) + view_meta.extensions = row.get('extensions', {}) return view_meta @@ -2289,14 +2706,17 @@ def _query_all(self): QueryMessage(query=self._SELECT_VIEWS, consistency_level=cl) ] - responses = self.connection.wait_for_responses(*queries, timeout=self.timeout, fail_on_error=False) - (ks_success, ks_result), (table_success, table_result), \ - (col_success, col_result), (types_success, types_result), \ - (functions_success, functions_result), \ - (aggregates_success, aggregates_result), \ - (triggers_success, triggers_result), \ - (indexes_success, indexes_result), \ - (views_success, views_result) = responses + ((ks_success, ks_result), + (table_success, table_result), + (col_success, col_result), + (types_success, types_result), + (functions_success, functions_result), + (aggregates_success, aggregates_result), + (triggers_success, triggers_result), + (indexes_success, indexes_result), + (views_success, views_result)) = self.connection.wait_for_responses( + *queries, timeout=self.timeout, fail_on_error=False + ) self.keyspaces_result = self._handle_results(ks_success, ks_result) self.tables_result = self._handle_results(table_success, table_result) @@ -2328,34 +2748,351 @@ def _schema_type_to_cql(type_string): return type_string -class TableMetadataV3(TableMetadata): - compaction_options = {} +class SchemaParserDSE60(SchemaParserV3): + """ + For DSE 6.0+ + """ + recognized_table_options = (SchemaParserV3.recognized_table_options + + ("nodesync",)) - option_maps = ['compaction', 'compression', 'caching'] - @property - def is_cql_compatible(self): - return True +class SchemaParserV4(SchemaParserV3): - @classmethod - def _make_option_strings(cls, options_map): - ret = [] - options_copy = dict(options_map.items()) + recognized_table_options = ( + 'additional_write_policy', + 'bloom_filter_fp_chance', + 'caching', + 'cdc', + 'comment', + 'compaction', + 'compression', + 'crc_check_chance', + 'default_time_to_live', + 'gc_grace_seconds', + 'max_index_interval', + 'memtable_flush_period_in_ms', + 
'min_index_interval', + 'read_repair', + 'speculative_retry') - for option in cls.option_maps: - value = options_copy.get(option) - if isinstance(value, Mapping): - del options_copy[option] - params = ("'%s': '%s'" % (k, v) for k, v in value.items()) - ret.append("%s = {%s}" % (option, ', '.join(params))) + _SELECT_VIRTUAL_KEYSPACES = 'SELECT * from system_virtual_schema.keyspaces' + _SELECT_VIRTUAL_TABLES = 'SELECT * from system_virtual_schema.tables' + _SELECT_VIRTUAL_COLUMNS = 'SELECT * from system_virtual_schema.columns' - for name, value in options_copy.items(): - if value is not None: - if name == "comment": - value = value or "" - ret.append("%s = %s" % (name, protect_value(value))) + def __init__(self, connection, timeout): + super(SchemaParserV4, self).__init__(connection, timeout) + self.virtual_keyspaces_rows = defaultdict(list) + self.virtual_tables_rows = defaultdict(list) + self.virtual_columns_rows = defaultdict(lambda: defaultdict(list)) - return list(sorted(ret)) + def _query_all(self): + cl = ConsistencyLevel.ONE + # todo: this duplicates V3; we should find a way for _query_all methods + # to extend each other. + queries = [ + # copied from V3 + QueryMessage(query=self._SELECT_KEYSPACES, consistency_level=cl), + QueryMessage(query=self._SELECT_TABLES, consistency_level=cl), + QueryMessage(query=self._SELECT_COLUMNS, consistency_level=cl), + QueryMessage(query=self._SELECT_TYPES, consistency_level=cl), + QueryMessage(query=self._SELECT_FUNCTIONS, consistency_level=cl), + QueryMessage(query=self._SELECT_AGGREGATES, consistency_level=cl), + QueryMessage(query=self._SELECT_TRIGGERS, consistency_level=cl), + QueryMessage(query=self._SELECT_INDEXES, consistency_level=cl), + QueryMessage(query=self._SELECT_VIEWS, consistency_level=cl), + # V4-only queries + QueryMessage(query=self._SELECT_VIRTUAL_KEYSPACES, consistency_level=cl), + QueryMessage(query=self._SELECT_VIRTUAL_TABLES, consistency_level=cl), + QueryMessage(query=self._SELECT_VIRTUAL_COLUMNS, consistency_level=cl) + ] + + responses = self.connection.wait_for_responses( + *queries, timeout=self.timeout, fail_on_error=False) + ( + # copied from V3 + (ks_success, ks_result), + (table_success, table_result), + (col_success, col_result), + (types_success, types_result), + (functions_success, functions_result), + (aggregates_success, aggregates_result), + (triggers_success, triggers_result), + (indexes_success, indexes_result), + (views_success, views_result), + # V4-only responses + (virtual_ks_success, virtual_ks_result), + (virtual_table_success, virtual_table_result), + (virtual_column_success, virtual_column_result) + ) = responses + + # copied from V3 + self.keyspaces_result = self._handle_results(ks_success, ks_result) + self.tables_result = self._handle_results(table_success, table_result) + self.columns_result = self._handle_results(col_success, col_result) + self.triggers_result = self._handle_results(triggers_success, triggers_result) + self.types_result = self._handle_results(types_success, types_result) + self.functions_result = self._handle_results(functions_success, functions_result) + self.aggregates_result = self._handle_results(aggregates_success, aggregates_result) + self.indexes_result = self._handle_results(indexes_success, indexes_result) + self.views_result = self._handle_results(views_success, views_result) + # V4-only results + # These tables don't exist in some DSE versions reporting 4.X so we can + # ignore them if we got an error + self.virtual_keyspaces_result = self._handle_results( + 
virtual_ks_success, virtual_ks_result, + expected_failures=(InvalidRequest,) + ) + self.virtual_tables_result = self._handle_results( + virtual_table_success, virtual_table_result, + expected_failures=(InvalidRequest,) + ) + self.virtual_columns_result = self._handle_results( + virtual_column_success, virtual_column_result, + expected_failures=(InvalidRequest,) + ) + + self._aggregate_results() + + def _aggregate_results(self): + super(SchemaParserV4, self)._aggregate_results() + + m = self.virtual_tables_rows + for row in self.virtual_tables_result: + m[row["keyspace_name"]].append(row) + + m = self.virtual_columns_rows + for row in self.virtual_columns_result: + ks_name = row['keyspace_name'] + tab_name = row[self._table_name_col] + m[ks_name][tab_name].append(row) + + def get_all_keyspaces(self): + for x in super(SchemaParserV4, self).get_all_keyspaces(): + yield x + + for row in self.virtual_keyspaces_result: + ks_name = row['keyspace_name'] + keyspace_meta = self._build_keyspace_metadata(row) + keyspace_meta.virtual = True + + for table_row in self.virtual_tables_rows.get(ks_name, []): + table_name = table_row[self._table_name_col] + + col_rows = self.virtual_columns_rows[ks_name][table_name] + keyspace_meta._add_table_metadata( + self._build_table_metadata(table_row, + col_rows=col_rows, + virtual=True) + ) + yield keyspace_meta + + @staticmethod + def _build_keyspace_metadata_internal(row): + # necessary fields that aren't int virtual ks + row["durable_writes"] = row.get("durable_writes", None) + row["replication"] = row.get("replication", {}) + row["replication"]["class"] = row["replication"].get("class", None) + return super(SchemaParserV4, SchemaParserV4)._build_keyspace_metadata_internal(row) + + +class SchemaParserDSE67(SchemaParserV4): + """ + For DSE 6.7+ + """ + recognized_table_options = (SchemaParserV4.recognized_table_options + + ("nodesync",)) + + +class SchemaParserDSE68(SchemaParserDSE67): + """ + For DSE 6.8+ + """ + + _SELECT_VERTICES = "SELECT * FROM system_schema.vertices" + _SELECT_EDGES = "SELECT * FROM system_schema.edges" + + _table_metadata_class = TableMetadataDSE68 + + def __init__(self, connection, timeout): + super(SchemaParserDSE68, self).__init__(connection, timeout) + self.keyspace_table_vertex_rows = defaultdict(lambda: defaultdict(list)) + self.keyspace_table_edge_rows = defaultdict(lambda: defaultdict(list)) + + def get_all_keyspaces(self): + for keyspace_meta in super(SchemaParserDSE68, self).get_all_keyspaces(): + self._build_graph_metadata(keyspace_meta) + yield keyspace_meta + + def get_table(self, keyspaces, keyspace, table): + table_meta = super(SchemaParserDSE68, self).get_table(keyspaces, keyspace, table) + cl = ConsistencyLevel.ONE + where_clause = bind_params(" WHERE keyspace_name = %%s AND %s = %%s" % (self._table_name_col), (keyspace, table), _encoder) + vertices_query = QueryMessage(query=self._SELECT_VERTICES + where_clause, consistency_level=cl) + edges_query = QueryMessage(query=self._SELECT_EDGES + where_clause, consistency_level=cl) + + (vertices_success, vertices_result), (edges_success, edges_result) \ + = self.connection.wait_for_responses(vertices_query, edges_query, timeout=self.timeout, fail_on_error=False) + vertices_result = self._handle_results(vertices_success, vertices_result) + edges_result = self._handle_results(edges_success, edges_result) + + try: + if vertices_result: + table_meta.vertex = self._build_table_vertex_metadata(vertices_result[0]) + elif edges_result: + table_meta.edge = 
self._build_table_edge_metadata(keyspaces[keyspace], edges_result[0]) + except Exception: + table_meta.vertex = None + table_meta.edge = None + table_meta._exc_info = sys.exc_info() + log.exception("Error while parsing graph metadata for table %s.%s.", keyspace, table) + + return table_meta + + @staticmethod + def _build_keyspace_metadata_internal(row): + name = row["keyspace_name"] + durable_writes = row.get("durable_writes", None) + replication = dict(row.get("replication")) if 'replication' in row else {} + replication_class = replication.pop("class") if 'class' in replication else None + graph_engine = row.get("graph_engine", None) + return KeyspaceMetadata(name, durable_writes, replication_class, replication, graph_engine) + + def _build_graph_metadata(self, keyspace_meta): + + def _build_table_graph_metadata(table_meta): + for row in self.keyspace_table_vertex_rows[keyspace_meta.name][table_meta.name]: + table_meta.vertex = self._build_table_vertex_metadata(row) + + for row in self.keyspace_table_edge_rows[keyspace_meta.name][table_meta.name]: + table_meta.edge = self._build_table_edge_metadata(keyspace_meta, row) + + try: + # Make sure we process vertices before edges + for table_meta in [t for t in keyspace_meta.tables.values() + if t.name in self.keyspace_table_vertex_rows[keyspace_meta.name]]: + _build_table_graph_metadata(table_meta) + + # all other tables... + for table_meta in [t for t in keyspace_meta.tables.values() + if t.name not in self.keyspace_table_vertex_rows[keyspace_meta.name]]: + _build_table_graph_metadata(table_meta) + except Exception: + # schema error, remove all graph metadata for this keyspace + for t in keyspace_meta.tables.values(): + t.edge = t.vertex = None + keyspace_meta._exc_info = sys.exc_info() + log.exception("Error while parsing graph metadata for keyspace %s", keyspace_meta.name) + + @staticmethod + def _build_table_vertex_metadata(row): + return VertexMetadata(row.get("keyspace_name"), row.get("table_name"), + row.get("label_name")) + + @staticmethod + def _build_table_edge_metadata(keyspace_meta, row): + from_table = row.get("from_table") + from_table_meta = keyspace_meta.tables.get(from_table) + from_label = from_table_meta.vertex.label_name + to_table = row.get("to_table") + to_table_meta = keyspace_meta.tables.get(to_table) + to_label = to_table_meta.vertex.label_name + + return EdgeMetadata( + row.get("keyspace_name"), row.get("table_name"), + row.get("label_name"), from_table, from_label, + row.get("from_partition_key_columns"), + row.get("from_clustering_columns"), to_table, to_label, + row.get("to_partition_key_columns"), + row.get("to_clustering_columns")) + + def _query_all(self): + cl = ConsistencyLevel.ONE + queries = [ + # copied from v4 + QueryMessage(query=self._SELECT_KEYSPACES, consistency_level=cl), + QueryMessage(query=self._SELECT_TABLES, consistency_level=cl), + QueryMessage(query=self._SELECT_COLUMNS, consistency_level=cl), + QueryMessage(query=self._SELECT_TYPES, consistency_level=cl), + QueryMessage(query=self._SELECT_FUNCTIONS, consistency_level=cl), + QueryMessage(query=self._SELECT_AGGREGATES, consistency_level=cl), + QueryMessage(query=self._SELECT_TRIGGERS, consistency_level=cl), + QueryMessage(query=self._SELECT_INDEXES, consistency_level=cl), + QueryMessage(query=self._SELECT_VIEWS, consistency_level=cl), + QueryMessage(query=self._SELECT_VIRTUAL_KEYSPACES, consistency_level=cl), + QueryMessage(query=self._SELECT_VIRTUAL_TABLES, consistency_level=cl), + QueryMessage(query=self._SELECT_VIRTUAL_COLUMNS, 
consistency_level=cl), + # dse6.8 only + QueryMessage(query=self._SELECT_VERTICES, consistency_level=cl), + QueryMessage(query=self._SELECT_EDGES, consistency_level=cl) + ] + + responses = self.connection.wait_for_responses( + *queries, timeout=self.timeout, fail_on_error=False) + ( + # copied from V4 + (ks_success, ks_result), + (table_success, table_result), + (col_success, col_result), + (types_success, types_result), + (functions_success, functions_result), + (aggregates_success, aggregates_result), + (triggers_success, triggers_result), + (indexes_success, indexes_result), + (views_success, views_result), + (virtual_ks_success, virtual_ks_result), + (virtual_table_success, virtual_table_result), + (virtual_column_success, virtual_column_result), + # dse6.8 responses + (vertices_success, vertices_result), + (edges_success, edges_result) + ) = responses + + # copied from V4 + self.keyspaces_result = self._handle_results(ks_success, ks_result) + self.tables_result = self._handle_results(table_success, table_result) + self.columns_result = self._handle_results(col_success, col_result) + self.triggers_result = self._handle_results(triggers_success, triggers_result) + self.types_result = self._handle_results(types_success, types_result) + self.functions_result = self._handle_results(functions_success, functions_result) + self.aggregates_result = self._handle_results(aggregates_success, aggregates_result) + self.indexes_result = self._handle_results(indexes_success, indexes_result) + self.views_result = self._handle_results(views_success, views_result) + + # These tables don't exist in some DSE versions reporting 4.X so we can + # ignore them if we got an error + self.virtual_keyspaces_result = self._handle_results( + virtual_ks_success, virtual_ks_result, + expected_failures=(InvalidRequest,) + ) + self.virtual_tables_result = self._handle_results( + virtual_table_success, virtual_table_result, + expected_failures=(InvalidRequest,) + ) + self.virtual_columns_result = self._handle_results( + virtual_column_success, virtual_column_result, + expected_failures=(InvalidRequest,) + ) + + # dse6.8-only results + self.vertices_result = self._handle_results(vertices_success, vertices_result) + self.edges_result = self._handle_results(edges_success, edges_result) + + self._aggregate_results() + + def _aggregate_results(self): + super(SchemaParserDSE68, self)._aggregate_results() + + m = self.keyspace_table_vertex_rows + for row in self.vertices_result: + ksname = row["keyspace_name"] + cfname = row['table_name'] + m[ksname][cfname].append(row) + + m = self.keyspace_table_edge_rows + for row in self.edges_result: + ksname = row["keyspace_name"] + cfname = row['table_name'] + m[ksname][cfname].append(row) class MaterializedViewMetadata(object): @@ -2364,8 +3101,7 @@ class MaterializedViewMetadata(object): """ keyspace_name = None - - """ A string name of the view.""" + """ A string name of the keyspace of this view.""" name = None """ A string name of the view.""" @@ -2406,6 +3142,11 @@ class MaterializedViewMetadata(object): view. 
""" + extensions = None + """ + Metadata describing configuration for table extensions + """ + def __init__(self, keyspace_name, view_name, base_table_name, include_all_columns, where_clause, options): self.keyspace_name = keyspace_name self.name = view_name @@ -2442,26 +3183,123 @@ def as_cql_query(self, formatted=False): properties = TableMetadataV3._property_string(formatted, self.clustering_key, self.options) - return "CREATE MATERIALIZED VIEW %(keyspace)s.%(name)s AS%(sep)s" \ - "SELECT %(selected_cols)s%(sep)s" \ - "FROM %(keyspace)s.%(base_table)s%(sep)s" \ - "WHERE %(where_clause)s%(sep)s" \ - "PRIMARY KEY %(pk)s%(sep)s" \ - "WITH %(properties)s" % locals() + ret = ("CREATE MATERIALIZED VIEW %(keyspace)s.%(name)s AS%(sep)s" + "SELECT %(selected_cols)s%(sep)s" + "FROM %(keyspace)s.%(base_table)s%(sep)s" + "WHERE %(where_clause)s%(sep)s" + "PRIMARY KEY %(pk)s%(sep)s" + "WITH %(properties)s") % locals() + + if self.extensions: + registry = _RegisteredExtensionType._extension_registry + for k in registry.keys() & self.extensions: # no viewkeys on OrderedMapSerializeKey + ext = registry[k] + cql = ext.after_table_cql(self, k, self.extensions[k]) + if cql: + ret += "\n\n%s" % (cql,) + return ret def export_as_string(self): return self.as_cql_query(formatted=True) + ";" -def get_schema_parser(connection, timeout): - server_version = connection.server_version - if server_version.startswith('3'): +class VertexMetadata(object): + """ + A representation of a vertex on a table + """ + + keyspace_name = None + """ A string name of the keyspace. """ + + table_name = None + """ A string name of the table this vertex is on. """ + + label_name = None + """ A string name of the label of this vertex.""" + + def __init__(self, keyspace_name, table_name, label_name): + self.keyspace_name = keyspace_name + self.table_name = table_name + self.label_name = label_name + + +class EdgeMetadata(object): + """ + A representation of an edge on a table + """ + + keyspace_name = None + """A string name of the keyspace """ + + table_name = None + """A string name of the table this edge is on""" + + label_name = None + """A string name of the label of this edge""" + + from_table = None + """A string name of the from table of this edge (incoming vertex)""" + + from_label = None + """A string name of the from table label of this edge (incoming vertex)""" + + from_partition_key_columns = None + """The columns that match the partition key of the incoming vertex table.""" + + from_clustering_columns = None + """The columns that match the clustering columns of the incoming vertex table.""" + + to_table = None + """A string name of the to table of this edge (outgoing vertex)""" + + to_label = None + """A string name of the to table label of this edge (outgoing vertex)""" + + to_partition_key_columns = None + """The columns that match the partition key of the outgoing vertex table.""" + + to_clustering_columns = None + """The columns that match the clustering columns of the outgoing vertex table.""" + + def __init__( + self, keyspace_name, table_name, label_name, from_table, + from_label, from_partition_key_columns, from_clustering_columns, + to_table, to_label, to_partition_key_columns, + to_clustering_columns): + self.keyspace_name = keyspace_name + self.table_name = table_name + self.label_name = label_name + self.from_table = from_table + self.from_label = from_label + self.from_partition_key_columns = from_partition_key_columns + self.from_clustering_columns = from_clustering_columns + self.to_table = to_table + 
self.to_label = to_label + self.to_partition_key_columns = to_partition_key_columns + self.to_clustering_columns = to_clustering_columns + + +def get_schema_parser(connection, server_version, dse_version, timeout): + version = Version(server_version) + if dse_version: + v = Version(dse_version) + if v >= Version('6.8.0'): + return SchemaParserDSE68(connection, timeout) + elif v >= Version('6.7.0'): + return SchemaParserDSE67(connection, timeout) + elif v >= Version('6.0.0'): + return SchemaParserDSE60(connection, timeout) + + if version >= Version('4-a'): + return SchemaParserV4(connection, timeout) + elif version >= Version('3.0.0'): return SchemaParserV3(connection, timeout) else: # we could further specialize by version. Right now just refactoring the # multi-version parser we have as of C* 2.2.0rc1. return SchemaParserV22(connection, timeout) + def _cql_from_cass_type(cass_type): """ A string representation of the type for this column, such as "varchar" @@ -2471,3 +3309,106 @@ def _cql_from_cass_type(cass_type): return cass_type.subtypes[0].cql_parameterized_type() else: return cass_type.cql_parameterized_type() + + +class RLACTableExtension(RegisteredTableExtension): + name = "DSE_RLACA" + + @classmethod + def after_table_cql(cls, table_meta, ext_key, ext_blob): + return "RESTRICT ROWS ON %s.%s USING %s;" % (protect_name(table_meta.keyspace_name), + protect_name(table_meta.name), + protect_name(ext_blob.decode('utf-8'))) +NO_VALID_REPLICA = object() + + +def group_keys_by_replica(session, keyspace, table, keys): + """ + Returns a :class:`dict` with the keys grouped per host. This can be + used to more accurately group by IN clause or to batch the keys per host. + + If a valid replica is not found for a particular key it will be grouped under + :class:`~.NO_VALID_REPLICA` + + Example usage:: + result = group_keys_by_replica( + session, "system", "peers", + (("127.0.0.1", ), ("127.0.0.2", )) + ) + """ + cluster = session.cluster + + partition_keys = cluster.metadata.keyspaces[keyspace].tables[table].partition_key + + serializers = list(types._cqltypes[partition_key.cql_type] for partition_key in partition_keys) + keys_per_host = defaultdict(list) + distance = cluster._default_load_balancing_policy.distance + + for key in keys: + serialized_key = [serializer.serialize(pk, cluster.protocol_version) + for serializer, pk in zip(serializers, key)] + if len(serialized_key) == 1: + routing_key = serialized_key[0] + else: + routing_key = b"".join(struct.pack(">H%dsB" % len(p), len(p), p, 0) for p in serialized_key) + all_replicas = cluster.metadata.get_replicas(keyspace, routing_key) + # First check if there are local replicas + valid_replicas = [host for host in all_replicas if + host.is_up and distance(host) == HostDistance.LOCAL] + if not valid_replicas: + valid_replicas = [host for host in all_replicas if host.is_up] + + if valid_replicas: + keys_per_host[random.choice(valid_replicas)].append(key) + else: + # We will group under this statement all the keys for which + # we haven't found a valid replica + keys_per_host[NO_VALID_REPLICA].append(key) + + return dict(keys_per_host) + + +# TODO next major reorg +class _NodeInfo(object): + """ + Internal utility functions to determine the different host addresses/ports + from a local or peers row. 
+ """ + + @staticmethod + def get_broadcast_rpc_address(row): + # TODO next major, change the parsing logic to avoid any + # overriding of a non-null value + addr = row.get("rpc_address") + if "native_address" in row: + addr = row.get("native_address") + if "native_transport_address" in row: + addr = row.get("native_transport_address") + if not addr or addr in ["0.0.0.0", "::"]: + addr = row.get("peer") + + return addr + + @staticmethod + def get_broadcast_rpc_port(row): + port = row.get("rpc_port") + if port is None or port == 0: + port = row.get("native_port") + + return port if port and port > 0 else None + + @staticmethod + def get_broadcast_address(row): + addr = row.get("broadcast_address") + if addr is None: + addr = row.get("peer") + + return addr + + @staticmethod + def get_broadcast_port(row): + port = row.get("broadcast_port") + if port is None or port == 0: + port = row.get("peer_port") + + return port if port and port > 0 else None diff --git a/cassandra/metrics.py b/cassandra/metrics.py index cf1f25c15d..a1eadc1fc4 100644 --- a/cassandra/metrics.py +++ b/cassandra/metrics.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -20,7 +22,7 @@ except ImportError: raise ImportError( "The scales library is required for metrics support: " - "https://pypi.python.org/pypi/scales") + "https://pypi.org/project/scales/") log = logging.getLogger(__name__) @@ -28,6 +30,8 @@ class Metrics(object): """ A collection of timers and counters for various performance metrics. + + Timer metrics are represented as floating point seconds. """ request_timer = None @@ -35,17 +39,17 @@ class Metrics(object): A :class:`greplin.scales.PmfStat` timer for requests. This is a dict-like object with the following keys: - * count - number of requests that have been timed - * min - min latency - * max - max latency - * mean - mean latency - * stdev - standard deviation for latencies - * median - median latency - * 75percentile - 75th percentile latencies - * 97percentile - 97th percentile latencies - * 98percentile - 98th percentile latencies - * 99percentile - 99th percentile latencies - * 999percentile - 99.9th percentile latencies + * count - number of requests that have been timed + * min - min latency + * max - max latency + * mean - mean latency + * stddev - standard deviation for latencies + * median - median latency + * 75percentile - 75th percentile latencies + * 95percentile - 95th percentile latencies + * 98percentile - 98th percentile latencies + * 99percentile - 99th percentile latencies + * 999percentile - 99.9th percentile latencies """ connection_errors = None @@ -111,10 +115,14 @@ class Metrics(object): the driver currently has open. 
""" + _stats_counter = 0 + def __init__(self, cluster_proxy): log.debug("Starting metric capture") - self.stats = scales.collection('/cassandra', + self.stats_name = 'cassandra-{0}'.format(str(self._stats_counter)) + Metrics._stats_counter += 1 + self.stats = scales.collection(self.stats_name, scales.PmfStat('request_timer'), scales.IntStat('connection_errors'), scales.IntStat('write_timeouts'), @@ -132,6 +140,11 @@ def __init__(self, cluster_proxy): scales.Stat('open_connections', lambda: sum(sum(p.open_count for p in s._pools.values()) for s in cluster_proxy.sessions))) + # TODO, to be removed in 4.0 + # /cassandra contains the metrics of the first cluster registered + if 'cassandra' not in scales._Stats.stats: + scales._Stats.stats['cassandra'] = scales._Stats.stats[self.stats_name] + self.request_timer = self.stats.request_timer self.connection_errors = self.stats.connection_errors self.write_timeouts = self.stats.write_timeouts @@ -164,3 +177,27 @@ def on_ignore(self): def on_retry(self): self.stats.retries += 1 + + def get_stats(self): + """ + Returns the metrics for the registered cluster instance. + """ + return scales.getStats()[self.stats_name] + + def set_stats_name(self, stats_name): + """ + Set the metrics stats name. + The stats_name is a string used to access the metrics through scales: scales.getStats()[] + Default is 'cassandra-'. + """ + + if self.stats_name == stats_name: + return + + if stats_name in scales._Stats.stats: + raise ValueError('"{0}" already exists in stats.'.format(stats_name)) + + stats = scales._Stats.stats[self.stats_name] + del scales._Stats.stats[self.stats_name] + self.stats_name = stats_name + scales._Stats.stats[self.stats_name] = stats diff --git a/cassandra/murmur3.py b/cassandra/murmur3.py index 61180c0121..282c43578d 100644 --- a/cassandra/murmur3.py +++ b/cassandra/murmur3.py @@ -1,4 +1,3 @@ -from six.moves import range import struct @@ -7,7 +6,9 @@ def body_and_tail(data): nblocks = l // 16 tail = l % 16 if nblocks: - return struct.unpack_from('qq' * nblocks, data), struct.unpack_from('b' * tail, data, -tail), l + # we use '<', specifying little-endian byte order for data bigger than + # a byte so behavior is the same on little- and big-endian platforms + return struct.unpack_from('<' + ('qq' * nblocks), data), struct.unpack_from('b' * tail, data, -tail), l else: return tuple(), struct.unpack_from('b' * tail, data, -tail), l diff --git a/cassandra/numpy_parser.pyx b/cassandra/numpy_parser.pyx index 920a3efdd1..2377258b36 100644 --- a/cassandra/numpy_parser.pyx +++ b/cassandra/numpy_parser.pyx @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -13,7 +15,7 @@ # limitations under the License. 
""" -This module provider an optional protocol parser that returns +This module provides an optional protocol parser that returns NumPy arrays. ============================================================================= @@ -25,7 +27,7 @@ as numpy is an optional dependency. include "ioutils.pyx" cimport cython -from libc.stdint cimport uint64_t +from libc.stdint cimport uint64_t, uint8_t from cpython.ref cimport Py_INCREF, PyObject from cassandra.bytesio cimport BytesIOReader @@ -35,7 +37,6 @@ from cassandra import cqltypes from cassandra.util import is_little_endian import numpy as np -# import pandas as pd cdef extern from "numpyFlags.h": # Include 'numpyFlags.h' into the generated C code to disable the @@ -52,12 +53,14 @@ ctypedef struct ArrDesc: Py_uintptr_t buf_ptr int stride # should be large enough as we allocate contiguous arrays int is_object + Py_uintptr_t mask_ptr arrDescDtype = np.dtype( [ ('buf_ptr', np.uintp) , ('stride', np.dtype('i')) , ('is_object', np.dtype('i')) - ]) + , ('mask_ptr', np.uintp) + ], align=True) _cqltype_to_numpy = { cqltypes.LongType: np.dtype('>i8'), @@ -70,6 +73,7 @@ _cqltype_to_numpy = { obj_dtype = np.dtype('O') +cdef uint8_t mask_true = 0x01 cdef class NumpyParser(ColumnParser): """Decode a ResultMessage into a bunch of NumPy arrays""" @@ -116,7 +120,11 @@ def make_arrays(ParseDesc desc, array_size): arr = make_array(coltype, array_size) array_descs[i]['buf_ptr'] = arr.ctypes.data array_descs[i]['stride'] = arr.strides[0] - array_descs[i]['is_object'] = coltype not in _cqltype_to_numpy + array_descs[i]['is_object'] = arr.dtype is obj_dtype + try: + array_descs[i]['mask_ptr'] = arr.mask.ctypes.data + except AttributeError: + array_descs[i]['mask_ptr'] = 0 arrays.append(arr) return array_descs, arrays @@ -126,8 +134,12 @@ def make_array(coltype, array_size): """ Allocate a new NumPy array of the given column type and size. """ - dtype = _cqltype_to_numpy.get(coltype, obj_dtype) - return np.empty((array_size,), dtype=dtype) + try: + a = np.ma.empty((array_size,), dtype=_cqltype_to_numpy[coltype]) + a.mask = np.zeros((array_size,), dtype=bool) + except KeyError: + a = np.empty((array_size,), dtype=obj_dtype) + return a #### Parse rows into NumPy arrays @@ -140,23 +152,23 @@ cdef inline int unpack_row( cdef Py_ssize_t i, rowsize = desc.rowsize cdef ArrDesc arr cdef Deserializer deserializer - for i in range(rowsize): get_buf(reader, &buf) arr = arrays[i] - if buf.size == 0: - raise ValueError("Cannot handle NULL value") if arr.is_object: deserializer = desc.deserializers[i] val = from_binary(deserializer, &buf, desc.protocol_version) Py_INCREF(val) ( arr.buf_ptr)[0] = val - else: + elif buf.size >= 0: memcpy( arr.buf_ptr, buf.ptr, buf.size) + else: + memcpy(arr.mask_ptr, &mask_true, 1) # Update the pointer into the array for the next time arrays[i].buf_ptr += arr.stride + arrays[i].mask_ptr += 1 return 0 diff --git a/cassandra/obj_parser.pyx b/cassandra/obj_parser.pyx index 2ec889ebc6..f1bfb551ef 100644 --- a/cassandra/obj_parser.pyx +++ b/cassandra/obj_parser.pyx @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -14,11 +16,15 @@ include "ioutils.pyx" +from cassandra import DriverException from cassandra.bytesio cimport BytesIOReader from cassandra.deserializers cimport Deserializer, from_binary +from cassandra.deserializers import find_deserializer from cassandra.parsing cimport ParseDesc, ColumnParser, RowParser from cassandra.tuple cimport tuple_new, tuple_set +from cpython.bytes cimport PyBytes_AsStringAndSize + cdef class ListParser(ColumnParser): """Decode a ResultMessage into a list of tuples (or other objects)""" @@ -57,18 +63,33 @@ cdef class TupleRowParser(RowParser): assert desc.rowsize >= 0 cdef Buffer buf + cdef Buffer newbuf cdef Py_ssize_t i, rowsize = desc.rowsize cdef Deserializer deserializer cdef tuple res = tuple_new(desc.rowsize) + ce_policy = desc.column_encryption_policy for i in range(rowsize): # Read the next few bytes get_buf(reader, &buf) # Deserialize bytes to python object deserializer = desc.deserializers[i] - val = from_binary(deserializer, &buf, desc.protocol_version) - + coldesc = desc.coldescs[i] + uses_ce = ce_policy and ce_policy.contains_column(coldesc) + try: + if uses_ce: + col_type = ce_policy.column_type(coldesc) + decrypted_bytes = ce_policy.decrypt(coldesc, to_bytes(&buf)) + PyBytes_AsStringAndSize(decrypted_bytes, &newbuf.ptr, &newbuf.size) + deserializer = find_deserializer(ce_policy.column_type(coldesc)) + val = from_binary(deserializer, &newbuf, desc.protocol_version) + else: + val = from_binary(deserializer, &buf, desc.protocol_version) + except Exception as e: + raise DriverException('Failed decoding result column "%s" of type %s: %s' % (desc.colnames[i], + desc.coltypes[i].cql_parameterized_type(), + str(e))) # Insert new object into tuple tuple_set(res, i, val) diff --git a/cassandra/parsing.pxd b/cassandra/parsing.pxd index abfc74d12d..1b3ed3dcbf 100644 --- a/cassandra/parsing.pxd +++ b/cassandra/parsing.pxd @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -18,6 +20,8 @@ from cassandra.deserializers cimport Deserializer cdef class ParseDesc: cdef public object colnames cdef public object coltypes + cdef public object column_encryption_policy + cdef public list coldescs cdef Deserializer[::1] deserializers cdef public int protocol_version cdef Py_ssize_t rowsize diff --git a/cassandra/parsing.pyx b/cassandra/parsing.pyx index 06cfe0bb8f..085544a362 100644 --- a/cassandra/parsing.pyx +++ b/cassandra/parsing.pyx @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -19,9 +21,11 @@ Module containing the definitions and declarations (parsing.pxd) for parsers. cdef class ParseDesc: """Description of what structure to parse""" - def __init__(self, colnames, coltypes, deserializers, protocol_version): + def __init__(self, colnames, coltypes, column_encryption_policy, coldescs, deserializers, protocol_version): self.colnames = colnames self.coltypes = coltypes + self.column_encryption_policy = column_encryption_policy + self.coldescs = coldescs self.deserializers = deserializers self.protocol_version = protocol_version self.rowsize = len(colnames) diff --git a/cassandra/policies.py b/cassandra/policies.py index 595717ca5a..d6f7063e7a 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,18 +14,25 @@ # See the License for the specific language governing permissions and # limitations under the License. 
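# --- Editorial sketch, not part of the patch above. The TupleRowParser/ParseDesc
# --- changes only require the column-encryption policy object to answer three
# --- calls per column: contains_column(coldesc), column_type(coldesc) and
# --- decrypt(coldesc, encrypted_bytes). A hypothetical toy policy showing that
# --- shape (XorDemoPolicy and its XOR "cipher" are invented purely for
# --- illustration; the driver ships its own policy implementations):
_KEY = 0x5A

class XorDemoPolicy(object):
    """Toy policy: 'encrypts' serialized column bytes by XOR-ing each byte with _KEY."""

    def __init__(self, encrypted_columns):
        # encrypted_columns: mapping of column descriptor -> CQL type of the plaintext
        self._columns = dict(encrypted_columns)

    def contains_column(self, coldesc):
        return coldesc in self._columns

    def column_type(self, coldesc):
        # the type the decrypted bytes should be deserialized as (cf. find_deserializer above)
        return self._columns[coldesc]

    def decrypt(self, coldesc, encrypted_bytes):
        return bytes(b ^ _KEY for b in encrypted_bytes)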
+from collections import namedtuple +from functools import lru_cache from itertools import islice, cycle, groupby, repeat import logging -from random import randint +from random import randint, shuffle from threading import Lock -import six +import socket +import warnings -from cassandra import ConsistencyLevel +log = logging.getLogger(__name__) -from six.moves import range +from cassandra import WriteType as WT -log = logging.getLogger(__name__) +# This is done this way because WriteType was originally +# defined here and in order not to break the API. +# It may be removed in the next major. +WriteType = WT +from cassandra import ConsistencyLevel, OperationTimedOut class HostDistance(object): """ @@ -119,7 +128,7 @@ def populate(self, cluster, hosts): def make_query_plan(self, working_keyspace=None, query=None): """ - Given a :class:`~.query.Statement` instance, return a iterable + Given a :class:`~.query.Statement` instance, return an iterable of :class:`.Host` instances which should be queried in that order. A generator may work well for custom implementations of this method. @@ -148,16 +157,13 @@ class RoundRobinPolicy(LoadBalancingPolicy): A subclass of :class:`.LoadBalancingPolicy` which evenly distributes queries across all nodes in the cluster, regardless of what datacenter the nodes may be in. - - This load balancing policy is used by default. """ _live_hosts = frozenset(()) + _position = 0 def populate(self, cluster, hosts): self._live_hosts = frozenset(hosts) - if len(hosts) <= 1: - self._position = 0 - else: + if len(hosts) > 1: self._position = randint(0, len(hosts) - 1) def distance(self, host): @@ -224,7 +230,7 @@ def __init__(self, local_dc='', used_hosts_per_remote_dc=0): self.used_hosts_per_remote_dc = used_hosts_per_remote_dc self._dc_live_hosts = {} self._position = 0 - self._contact_points = [] + self._endpoints = [] LoadBalancingPolicy.__init__(self) def _dc(self, host): @@ -235,7 +241,9 @@ def populate(self, cluster, hosts): self._dc_live_hosts[dc] = tuple(set(dc_hosts)) if not self.local_dc: - self._contact_points = cluster.contact_points + self._endpoints = [ + endpoint + for endpoint in cluster.endpoints_resolved] self._position = randint(0, len(hosts) - 1) if hosts else 0 @@ -278,13 +286,13 @@ def on_up(self, host): # not worrying about threads because this will happen during # control connection startup/refresh if not self.local_dc and host.datacenter: - if host.address in self._contact_points: + if host.endpoint in self._endpoints: self.local_dc = host.datacenter log.info("Using datacenter '%s' for DCAwareRoundRobinPolicy (via host '%s'); " "if incorrect, please specify a local_dc to the constructor, " "or limit contact points to local cluster nodes" % - (self.local_dc, host.address)) - del self._contact_points + (self.local_dc, host.endpoint)) + del self._endpoints dc = self._dc(host) with self._hosts_lock: @@ -318,8 +326,10 @@ class TokenAwarePolicy(LoadBalancingPolicy): This alters the child policy's behavior so that it first attempts to send queries to :attr:`~.HostDistance.LOCAL` replicas (as determined by the child policy) based on the :class:`.Statement`'s - :attr:`~.Statement.routing_key`. Once those hosts are exhausted, the - remaining hosts in the child policy's query plan will be used. + :attr:`~.Statement.routing_key`. If :attr:`.shuffle_replicas` is + truthy, these replicas will be yielded in a random order. 
Once those + hosts are exhausted, the remaining hosts in the child policy's query + plan will be used in the order provided by the child policy. If no :attr:`~.Statement.routing_key` is set on the query, the child policy's query plan will be used as is. @@ -327,9 +337,14 @@ class TokenAwarePolicy(LoadBalancingPolicy): _child_policy = None _cluster_metadata = None + shuffle_replicas = False + """ + Yield local replicas in a random order. + """ - def __init__(self, child_policy): + def __init__(self, child_policy, shuffle_replicas=False): self._child_policy = child_policy + self.shuffle_replicas = shuffle_replicas def populate(self, cluster, hosts): self._cluster_metadata = cluster.metadata @@ -337,7 +352,7 @@ def populate(self, cluster, hosts): def check_supported(self): if not self._cluster_metadata.can_support_partitioner(): - raise Exception( + raise RuntimeError( '%s cannot be used with the cluster partitioner (%s) because ' 'the relevant C extension for this driver was not compiled. ' 'See the installation instructions for details on building ' @@ -364,6 +379,8 @@ def make_query_plan(self, working_keyspace=None, query=None): yield host else: replicas = self._cluster_metadata.get_replicas(keyspace, routing_key) + if self.shuffle_replicas: + shuffle(replicas) for replica in replicas: if replica.is_up and \ child.distance(replica) == HostDistance.LOCAL: @@ -400,16 +417,20 @@ class WhiteListRoundRobinPolicy(RoundRobinPolicy): Where connection errors occur when connection attempts are made to private IP addresses remotely """ + def __init__(self, hosts): """ The `hosts` parameter should be a sequence of hosts to permit connections to. """ - self._allowed_hosts = hosts + self._allowed_hosts = tuple(hosts) + self._allowed_hosts_resolved = [endpoint[4][0] for a in self._allowed_hosts + for endpoint in socket.getaddrinfo(a, None, socket.AF_UNSPEC, socket.SOCK_STREAM)] + RoundRobinPolicy.__init__(self) def populate(self, cluster, hosts): - self._live_hosts = frozenset(h for h in hosts if h.address in self._allowed_hosts) + self._live_hosts = frozenset(h for h in hosts if h.address in self._allowed_hosts_resolved) if len(hosts) <= 1: self._position = 0 @@ -417,20 +438,132 @@ def populate(self, cluster, hosts): self._position = randint(0, len(hosts) - 1) def distance(self, host): - if host.address in self._allowed_hosts: + if host.address in self._allowed_hosts_resolved: return HostDistance.LOCAL else: return HostDistance.IGNORED def on_up(self, host): - if host.address in self._allowed_hosts: + if host.address in self._allowed_hosts_resolved: RoundRobinPolicy.on_up(self, host) def on_add(self, host): - if host.address in self._allowed_hosts: + if host.address in self._allowed_hosts_resolved: RoundRobinPolicy.on_add(self, host) +class HostFilterPolicy(LoadBalancingPolicy): + """ + A :class:`.LoadBalancingPolicy` subclass configured with a child policy, + and a single-argument predicate. This policy defers to the child policy for + hosts where ``predicate(host)`` is truthy. Hosts for which + ``predicate(host)`` is falsy will be considered :attr:`.IGNORED`, and will + not be used in a query plan. + + This can be used in the cases where you need a whitelist or blacklist + policy, e.g. to prepare for decommissioning nodes or for testing: + + .. 
code-block:: python + + def address_is_ignored(host): + return host.address in [ignored_address0, ignored_address1] + + blacklist_filter_policy = HostFilterPolicy( + child_policy=RoundRobinPolicy(), + predicate=address_is_ignored + ) + + cluster = Cluster( + primary_host, + load_balancing_policy=blacklist_filter_policy, + ) + + See the note in the :meth:`.make_query_plan` documentation for a caveat on + how wrapping ordering polices (e.g. :class:`.RoundRobinPolicy`) may break + desirable properties of the wrapped policy. + + Please note that whitelist and blacklist policies are not recommended for + general, day-to-day use. You probably want something like + :class:`.DCAwareRoundRobinPolicy`, which prefers a local DC but has + fallbacks, over a brute-force method like whitelisting or blacklisting. + """ + + def __init__(self, child_policy, predicate): + """ + :param child_policy: an instantiated :class:`.LoadBalancingPolicy` + that this one will defer to. + :param predicate: a one-parameter function that takes a :class:`.Host`. + If it returns a falsy value, the :class:`.Host` will + be :attr:`.IGNORED` and not returned in query plans. + """ + super(HostFilterPolicy, self).__init__() + self._child_policy = child_policy + self._predicate = predicate + + def on_up(self, host, *args, **kwargs): + return self._child_policy.on_up(host, *args, **kwargs) + + def on_down(self, host, *args, **kwargs): + return self._child_policy.on_down(host, *args, **kwargs) + + def on_add(self, host, *args, **kwargs): + return self._child_policy.on_add(host, *args, **kwargs) + + def on_remove(self, host, *args, **kwargs): + return self._child_policy.on_remove(host, *args, **kwargs) + + @property + def predicate(self): + """ + A predicate, set on object initialization, that takes a :class:`.Host` + and returns a value. If the value is falsy, the :class:`.Host` is + :class:`~HostDistance.IGNORED`. If the value is truthy, + :class:`.HostFilterPolicy` defers to the child policy to determine the + host's distance. + + This is a read-only value set in ``__init__``, implemented as a + ``property``. + """ + return self._predicate + + def distance(self, host): + """ + Checks if ``predicate(host)``, then returns + :attr:`~HostDistance.IGNORED` if falsy, and defers to the child policy + otherwise. + """ + if self.predicate(host): + return self._child_policy.distance(host) + else: + return HostDistance.IGNORED + + def populate(self, cluster, hosts): + self._child_policy.populate(cluster=cluster, hosts=hosts) + + def make_query_plan(self, working_keyspace=None, query=None): + """ + Defers to the child policy's + :meth:`.LoadBalancingPolicy.make_query_plan` and filters the results. + + Note that this filtering may break desirable properties of the wrapped + policy in some cases. For instance, imagine if you configure this + policy to filter out ``host2``, and to wrap a round-robin policy that + rotates through three hosts in the order ``host1, host2, host3``, + ``host2, host3, host1``, ``host3, host1, host2``, repeating. This + policy will yield ``host1, host3``, ``host3, host1``, ``host3, host1``, + disproportionately favoring ``host3``. 
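        A hypothetical sketch of that skew (illustrative only; plain strings
        stand in for :class:`.Host` instances, and the three rotations above
        are replayed by hand):

        .. code-block:: python

            from collections import Counter
            from itertools import cycle, islice

            hosts = ['host1', 'host2', 'host3']
            first_yielded = Counter()
            for start in range(3):
                # the wrapped round-robin plan for this rotation
                plan = islice(cycle(hosts), start, start + 3)
                # HostFilterPolicy drops hosts failing the predicate (here: host2)
                filtered = [h for h in plan if h != 'host2']
                first_yielded[filtered[0]] += 1

            # first_yielded == Counter({'host3': 2, 'host1': 1})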
+ """ + child_qp = self._child_policy.make_query_plan( + working_keyspace=working_keyspace, query=query + ) + for host in child_qp: + if self.predicate(host): + yield host + + def check_supported(self): + return self._child_policy.check_supported() + + class ConvictionPolicy(object): """ A policy which decides when hosts should be considered down @@ -468,7 +601,7 @@ class SimpleConvictionPolicy(ConvictionPolicy): """ def add_failure(self, connection_exc): - return True + return not isinstance(connection_exc, OperationTimedOut) def reset(self): pass @@ -485,7 +618,7 @@ class ReconnectionPolicy(object): def new_schedule(self): """ This should return a finite or infinite iterable of delays (each as a - floating point number of seconds) inbetween each failed reconnection + floating point number of seconds) in-between each failed reconnection attempt. Note that if the iterable is finite, reconnection attempts will cease once the iterable is exhausted. """ @@ -495,12 +628,12 @@ def new_schedule(self): class ConstantReconnectionPolicy(ReconnectionPolicy): """ A :class:`.ReconnectionPolicy` subclass which sleeps for a fixed delay - inbetween each reconnection attempt. + in-between each reconnection attempt. """ def __init__(self, delay, max_attempts=64): """ - `delay` should be a floating point number of seconds to wait inbetween + `delay` should be a floating point number of seconds to wait in-between each attempt. `max_attempts` should be a total number of attempts to be made before @@ -524,8 +657,12 @@ def new_schedule(self): class ExponentialReconnectionPolicy(ReconnectionPolicy): """ A :class:`.ReconnectionPolicy` subclass which exponentially increases - the length of the delay inbetween each reconnection attempt up to + the length of the delay in-between each reconnection attempt up to a set maximum delay. + + A random amount of jitter (+/- 15%) will be added to the pure exponential + delay value to avoid the situations where many reconnection handlers are + trying to reconnect at exactly the same time. """ # TODO: max_attempts is 64 to preserve legacy default behavior @@ -554,61 +691,24 @@ def __init__(self, base_delay, max_delay, max_attempts=64): self.max_attempts = max_attempts def new_schedule(self): - i=0 - while self.max_attempts == None or i < self.max_attempts: - yield min(self.base_delay * (2 ** i), self.max_delay) - i += 1 - - -class WriteType(object): - """ - For usage with :class:`.RetryPolicy`, this describe a type - of write operation. - """ - - SIMPLE = 0 - """ - A write to a single partition key. Such writes are guaranteed to be atomic - and isolated. - """ - - BATCH = 1 - """ - A write to multiple partition keys that used the distributed batch log to - ensure atomicity. - """ - - UNLOGGED_BATCH = 2 - """ - A write to multiple partition keys that did not use the distributed batch - log. Atomicity for such writes is not guaranteed. - """ - - COUNTER = 3 - """ - A counter write (for one or multiple partition keys). Such writes should - not be replayed in order to avoid overcount. - """ - - BATCH_LOG = 4 - """ - The initial write to the distributed batch log that Cassandra performs - internally before a BATCH write. - """ + i, overflowed = 0, False + while self.max_attempts is None or i < self.max_attempts: + if overflowed: + yield self.max_delay + else: + try: + yield self._add_jitter(min(self.base_delay * (2 ** i), self.max_delay)) + except OverflowError: + overflowed = True + yield self.max_delay - CAS = 5 - """ - A lighweight-transaction write, such as "DELETE ... 
IF EXISTS". - """ + i += 1 -WriteType.name_to_value = { - 'SIMPLE': WriteType.SIMPLE, - 'BATCH': WriteType.BATCH, - 'UNLOGGED_BATCH': WriteType.UNLOGGED_BATCH, - 'COUNTER': WriteType.COUNTER, - 'BATCH_LOG': WriteType.BATCH_LOG, - 'CAS': WriteType.CAS -} + # Adds -+ 15% to the delay provided + def _add_jitter(self, value): + jitter = randint(85, 115) + delay = (jitter * value) / 100 + return min(max(self.base_delay, delay), self.max_delay) class RetryPolicy(object): @@ -617,7 +717,7 @@ class RetryPolicy(object): timeout and unavailable failures. These are failures reported from the server side. Timeouts are configured by `settings in cassandra.yaml `_. - Unavailable failures occur when the coordinator cannot acheive the consistency + Unavailable failures occur when the coordinator cannot achieve the consistency level for a request. For further information see the method descriptions below. @@ -650,6 +750,12 @@ class or one of its subclasses. should be ignored but no more retries should be attempted. """ + RETRY_NEXT_HOST = 3 + """ + This should be returned from the below methods if the operation + should be retried on another connection. + """ + def on_read_timeout(self, query, consistency, required_responses, received_responses, data_retrieved, retry_num): """ @@ -677,11 +783,11 @@ def on_read_timeout(self, query, consistency, required_responses, a sufficient number of replicas responded (with data digests). """ if retry_num != 0: - return (self.RETHROW, None) + return self.RETHROW, None elif received_responses >= required_responses and not data_retrieved: - return (self.RETRY, consistency) + return self.RETRY, consistency else: - return (self.RETHROW, None) + return self.RETHROW, None def on_write_timeout(self, query, consistency, write_type, required_responses, received_responses, retry_num): @@ -705,23 +811,23 @@ def on_write_timeout(self, query, consistency, write_type, `retry_num` counts how many times the operation has been retried, so the first time this method is called, `retry_num` will be 0. - By default, failed write operations will retried at most once, and - they will only be retried if the `write_type` was + By default, a failed write operations will be retried at most once, and + will only be retried if the `write_type` was :attr:`~.WriteType.BATCH_LOG`. """ if retry_num != 0: - return (self.RETHROW, None) + return self.RETHROW, None elif write_type == WriteType.BATCH_LOG: - return (self.RETRY, consistency) + return self.RETRY, consistency else: - return (self.RETHROW, None) + return self.RETHROW, None def on_unavailable(self, query, consistency, required_replicas, alive_replicas, retry_num): """ This is called when the coordinator node determines that a read or write operation cannot be successful because the number of live replicas are too low to meet the requested :class:`.ConsistencyLevel`. - This means that the read or write operation was never forwared to + This means that the read or write operation was never forwarded to any replicas. `query` is the :class:`.Statement` that failed. @@ -737,9 +843,36 @@ def on_unavailable(self, query, consistency, required_replicas, alive_replicas, `retry_num` counts how many times the operation has been retried, so the first time this method is called, `retry_num` will be 0. - By default, no retries will be attempted and the error will be re-raised. + By default, if this is the first retry, it triggers a retry on the next + host in the query plan with the same consistency level. 
If this is not the + first retry, no retries will be attempted and the error will be re-raised. """ - return (self.RETHROW, None) + return (self.RETRY_NEXT_HOST, None) if retry_num == 0 else (self.RETHROW, None) + + def on_request_error(self, query, consistency, error, retry_num): + """ + This is called when an unexpected error happens. This can be in the + following situations: + + * On a connection error + * On server errors: overloaded, isBootstrapping, serverError, etc. + + `query` is the :class:`.Statement` that timed out. + + `consistency` is the :class:`.ConsistencyLevel` that the operation was + attempted at. + + `error` the instance of the exception. + + `retry_num` counts how many times the operation has been retried, so + the first time this method is called, `retry_num` will be 0. + + By default, it triggers a retry on the next host in the query plan + with the same consistency level. + """ + # TODO revisit this for the next major + # To preserve the same behavior than before, we don't take retry_num into account + return self.RETRY_NEXT_HOST, None class FallthroughRetryPolicy(RetryPolicy): @@ -749,17 +882,22 @@ class FallthroughRetryPolicy(RetryPolicy): """ def on_read_timeout(self, *args, **kwargs): - return (self.RETHROW, None) + return self.RETHROW, None def on_write_timeout(self, *args, **kwargs): - return (self.RETHROW, None) + return self.RETHROW, None def on_unavailable(self, *args, **kwargs): - return (self.RETHROW, None) + return self.RETHROW, None + + def on_request_error(self, *args, **kwargs): + return self.RETHROW, None class DowngradingConsistencyRetryPolicy(RetryPolicy): """ + *Deprecated:* This retry policy will be removed in the next major release. + A retry policy that sometimes retries with a lower consistency level than the one initially requested. @@ -771,7 +909,7 @@ class DowngradingConsistencyRetryPolicy(RetryPolicy): policy unless you have understood the cases where this can happen and are ok with that. It is also recommended to subclass this class so that queries that required a consistency level downgrade can be - recorded (so that repairs can be made later, etc). + recorded (so that repairs can be made later, etc.). This policy implements the same retries as :class:`.RetryPolicy`, but on top of that, it also retries in the following cases: @@ -805,47 +943,302 @@ class DowngradingConsistencyRetryPolicy(RetryPolicy): to make sure the data is persisted, and that reading something is better than reading nothing, even if there is a risk of reading stale data. 
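    One hypothetical way to record downgrades, as recommended above, is to
    subclass and hook the (private) ``_pick_consistency`` helper defined on
    this class; the subclass name, logger, and message below are illustrative
    only:

    .. code-block:: python

        import logging

        class LoggingDowngradingRetryPolicy(DowngradingConsistencyRetryPolicy):
            def _pick_consistency(self, num_responses):
                decision, consistency = super(
                    LoggingDowngradingRetryPolicy, self)._pick_consistency(num_responses)
                if decision == self.RETRY:
                    # record the downgrade so repairs can be scheduled later
                    logging.getLogger(__name__).warning(
                        "Downgrading consistency level to %s", consistency)
                return decision, consistency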
""" + def __init__(self, *args, **kwargs): + super(DowngradingConsistencyRetryPolicy, self).__init__(*args, **kwargs) + warnings.warn('DowngradingConsistencyRetryPolicy is deprecated ' + 'and will be removed in the next major release.', + DeprecationWarning) + def _pick_consistency(self, num_responses): if num_responses >= 3: - return (self.RETRY, ConsistencyLevel.THREE) + return self.RETRY, ConsistencyLevel.THREE elif num_responses >= 2: - return (self.RETRY, ConsistencyLevel.TWO) + return self.RETRY, ConsistencyLevel.TWO elif num_responses >= 1: - return (self.RETRY, ConsistencyLevel.ONE) + return self.RETRY, ConsistencyLevel.ONE else: - return (self.RETHROW, None) + return self.RETHROW, None def on_read_timeout(self, query, consistency, required_responses, received_responses, data_retrieved, retry_num): if retry_num != 0: - return (self.RETHROW, None) + return self.RETHROW, None + elif ConsistencyLevel.is_serial(consistency): + # Downgrading does not make sense for a CAS read query + return self.RETHROW, None elif received_responses < required_responses: return self._pick_consistency(received_responses) elif not data_retrieved: - return (self.RETRY, consistency) + return self.RETRY, consistency else: - return (self.RETHROW, None) + return self.RETHROW, None def on_write_timeout(self, query, consistency, write_type, required_responses, received_responses, retry_num): if retry_num != 0: - return (self.RETHROW, None) + return self.RETHROW, None if write_type in (WriteType.SIMPLE, WriteType.BATCH, WriteType.COUNTER): if received_responses > 0: # persisted on at least one replica - return (self.IGNORE, None) + return self.IGNORE, None else: - return (self.RETHROW, None) + return self.RETHROW, None elif write_type == WriteType.UNLOGGED_BATCH: return self._pick_consistency(received_responses) elif write_type == WriteType.BATCH_LOG: - return (self.RETRY, consistency) + return self.RETRY, consistency - return (self.RETHROW, None) + return self.RETHROW, None def on_unavailable(self, query, consistency, required_replicas, alive_replicas, retry_num): if retry_num != 0: - return (self.RETHROW, None) + return self.RETHROW, None + elif ConsistencyLevel.is_serial(consistency): + # failed at the paxos phase of a LWT, retry on the next host + return self.RETRY_NEXT_HOST, None else: return self._pick_consistency(alive_replicas) + + +class AddressTranslator(object): + """ + Interface for translating cluster-defined endpoints. + + The driver discovers nodes using server metadata and topology change events. Normally, + the endpoint defined by the server is the right way to connect to a node. In some environments, + these addresses may not be reachable, or not preferred (public vs. private IPs in cloud environments, + suboptimal routing, etc.). This interface allows for translating from server defined endpoints to + preferred addresses for driver connections. + + *Note:* :attr:`~Cluster.contact_points` provided while creating the :class:`~.Cluster` instance are not + translated using this mechanism -- only addresses received from Cassandra nodes are. + """ + def translate(self, addr): + """ + Accepts the node ip address, and returns a translated address to be used connecting to this node. 
+ """ + raise NotImplementedError() + + +class IdentityTranslator(AddressTranslator): + """ + Returns the endpoint with no translation + """ + def translate(self, addr): + return addr + + +class EC2MultiRegionTranslator(AddressTranslator): + """ + Resolves private ips of the hosts in the same datacenter as the client, and public ips of hosts in other datacenters. + """ + def translate(self, addr): + """ + Reverse DNS the public broadcast_address, then lookup that hostname to get the AWS-resolved IP, which + will point to the private IP address within the same datacenter. + """ + # get family of this address, so we translate to the same + family = socket.getaddrinfo(addr, 0, socket.AF_UNSPEC, socket.SOCK_STREAM)[0][0] + host = socket.getfqdn(addr) + for a in socket.getaddrinfo(host, 0, family, socket.SOCK_STREAM): + try: + return a[4][0] + except Exception: + pass + return addr + + +class SpeculativeExecutionPolicy(object): + """ + Interface for specifying speculative execution plans + """ + + def new_plan(self, keyspace, statement): + """ + Returns + + :param keyspace: + :param statement: + :return: + """ + raise NotImplementedError() + + +class SpeculativeExecutionPlan(object): + def next_execution(self, host): + raise NotImplementedError() + + +class NoSpeculativeExecutionPlan(SpeculativeExecutionPlan): + def next_execution(self, host): + return -1 + + +class NoSpeculativeExecutionPolicy(SpeculativeExecutionPolicy): + + def new_plan(self, keyspace, statement): + return NoSpeculativeExecutionPlan() + + +class ConstantSpeculativeExecutionPolicy(SpeculativeExecutionPolicy): + """ + A speculative execution policy that sends a new query every X seconds (**delay**) for a maximum of Y attempts (**max_attempts**). + """ + + def __init__(self, delay, max_attempts): + self.delay = delay + self.max_attempts = max_attempts + + class ConstantSpeculativeExecutionPlan(SpeculativeExecutionPlan): + def __init__(self, delay, max_attempts): + self.delay = delay + self.remaining = max_attempts + + def next_execution(self, host): + if self.remaining > 0: + self.remaining -= 1 + return self.delay + else: + return -1 + + def new_plan(self, keyspace, statement): + return self.ConstantSpeculativeExecutionPlan(self.delay, self.max_attempts) + + +class WrapperPolicy(LoadBalancingPolicy): + + def __init__(self, child_policy): + self._child_policy = child_policy + + def distance(self, *args, **kwargs): + return self._child_policy.distance(*args, **kwargs) + + def populate(self, cluster, hosts): + self._child_policy.populate(cluster, hosts) + + def on_up(self, *args, **kwargs): + return self._child_policy.on_up(*args, **kwargs) + + def on_down(self, *args, **kwargs): + return self._child_policy.on_down(*args, **kwargs) + + def on_add(self, *args, **kwargs): + return self._child_policy.on_add(*args, **kwargs) + + def on_remove(self, *args, **kwargs): + return self._child_policy.on_remove(*args, **kwargs) + + +class DefaultLoadBalancingPolicy(WrapperPolicy): + """ + A :class:`.LoadBalancingPolicy` wrapper that adds the ability to target a specific host first. + + If no host is set on the query, the child policy's query plan will be used as is. 
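    A hypothetical usage sketch (the address is a placeholder and must belong
    to a host known to the cluster metadata; the policy is passed to the
    :class:`~.Cluster` as usual and reads the optional ``target_host``
    attribute shown in ``make_query_plan`` below):

    .. code-block:: python

        from cassandra.policies import (DefaultLoadBalancingPolicy,
                                        TokenAwarePolicy, DCAwareRoundRobinPolicy)
        from cassandra.query import SimpleStatement

        # wrap an ordinary child policy
        policy = DefaultLoadBalancingPolicy(
            TokenAwarePolicy(DCAwareRoundRobinPolicy()))

        # steer a single statement to a specific node; it is tried first
        # only if it is known and up, otherwise the child plan is used as is
        statement = SimpleStatement("SELECT * FROM system.local")
        statement.target_host = '10.0.0.1'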
+ """ + + _cluster_metadata = None + + def populate(self, cluster, hosts): + self._cluster_metadata = cluster.metadata + self._child_policy.populate(cluster, hosts) + + def make_query_plan(self, working_keyspace=None, query=None): + if query and query.keyspace: + keyspace = query.keyspace + else: + keyspace = working_keyspace + + # TODO remove next major since execute(..., host=XXX) is now available + addr = getattr(query, 'target_host', None) if query else None + target_host = self._cluster_metadata.get_host(addr) + + child = self._child_policy + if target_host and target_host.is_up: + yield target_host + for h in child.make_query_plan(keyspace, query): + if h != target_host: + yield h + else: + for h in child.make_query_plan(keyspace, query): + yield h + + +# TODO for backward compatibility, remove in next major +class DSELoadBalancingPolicy(DefaultLoadBalancingPolicy): + """ + *Deprecated:* This will be removed in the next major release, + consider using :class:`.DefaultLoadBalancingPolicy`. + """ + def __init__(self, *args, **kwargs): + super(DSELoadBalancingPolicy, self).__init__(*args, **kwargs) + warnings.warn("DSELoadBalancingPolicy will be removed in 4.0. Consider using " + "DefaultLoadBalancingPolicy.", DeprecationWarning) + + +class NeverRetryPolicy(RetryPolicy): + def _rethrow(self, *args, **kwargs): + return self.RETHROW, None + + on_read_timeout = _rethrow + on_write_timeout = _rethrow + on_unavailable = _rethrow + + +ColDesc = namedtuple('ColDesc', ['ks', 'table', 'col']) + +class ColumnEncryptionPolicy(object): + """ + A policy enabling (mostly) transparent encryption and decryption of data before it is + sent to the cluster. + + Key materials and other configurations are specified on a per-column basis. This policy can + then be used by driver structures which are aware of the underlying columns involved in their + work. In practice this includes the following cases: + + * Prepared statements - data for columns specified by the cluster's policy will be transparently + encrypted before they are sent + * Rows returned from any query - data for columns specified by the cluster's policy will be + transparently decrypted before they are returned to the user + + To enable this functionality, create an instance of this class (or more likely a subclass) + before creating a cluster. This policy should then be configured and supplied to the Cluster + at creation time via the :attr:`.Cluster.column_encryption_policy` attribute. + """ + + def encrypt(self, coldesc, obj_bytes): + """ + Encrypt the specified bytes using the cryptography materials for the specified column. + Largely used internally, although this could also be used to encrypt values supplied + to non-prepared statements in a way that is consistent with this policy. + """ + raise NotImplementedError() + + def decrypt(self, coldesc, encrypted_bytes): + """ + Decrypt the specified (encrypted) bytes using the cryptography materials for the + specified column. Used internally; could be used externally as well but there's + not currently an obvious use case. + """ + raise NotImplementedError() + + def add_column(self, coldesc, key): + """ + Provide cryptography materials to be used when encrypted and/or decrypting data + for the specified column. + """ + raise NotImplementedError() + + def contains_column(self, coldesc): + """ + Predicate to determine if a specific column is supported by this policy. + Currently only used internally. 
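        A hypothetical minimal subclass, showing only the column bookkeeping
        this interface implies (the encryption itself is elided; real policies
        wrap actual cryptographic primitives):

        .. code-block:: python

            class InMemoryKeyColumnEncryptionPolicy(ColumnEncryptionPolicy):
                def __init__(self):
                    self._keys = {}

                def add_column(self, coldesc, key):
                    # coldesc is a ColDesc(ks, table, col) namedtuple
                    self._keys[coldesc] = key

                def contains_column(self, coldesc):
                    return coldesc in self._keys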
+ """ + raise NotImplementedError() + + def encode_and_encrypt(self, coldesc, obj): + """ + Helper function to enable use of this policy on simple (i.e. non-prepared) + statements. + """ + raise NotImplementedError() diff --git a/cassandra/pool.py b/cassandra/pool.py index 08cfd37271..37fdaee96b 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -16,6 +18,7 @@ Connection pooling and host management. """ +from functools import total_ordering import logging import socket import time @@ -27,7 +30,7 @@ from cassandra.util import WeakSet # NOQA from cassandra import AuthenticationFailed -from cassandra.connection import ConnectionException +from cassandra.connection import ConnectionException, EndPoint, DefaultEndPoint from cassandra.policies import HostDistance log = logging.getLogger(__name__) @@ -41,14 +44,73 @@ class NoConnectionsAvailable(Exception): pass +@total_ordering class Host(object): """ Represents a single Cassandra node. """ - address = None + endpoint = None + """ + The :class:`~.connection.EndPoint` to connect to the node. + """ + + broadcast_address = None + """ + broadcast address configured for the node, *if available*: + + 'system.local.broadcast_address' or 'system.peers.peer' (Cassandra 2-3) + 'system.local.broadcast_address' or 'system.peers_v2.peer' (Cassandra 4) + + This is not present in the ``system.local`` table for older versions of Cassandra. It + is also not queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``. + """ + + broadcast_port = None + """ + broadcast port configured for the node, *if available*: + + 'system.local.broadcast_port' or 'system.peers_v2.peer_port' (Cassandra 4) + + It is also not queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``. + """ + + broadcast_rpc_address = None + """ + The broadcast rpc address of the node: + + 'system.local.rpc_address' or 'system.peers.rpc_address' (Cassandra 3) + 'system.local.rpc_address' or 'system.peers.native_transport_address (DSE 6+)' + 'system.local.rpc_address' or 'system.peers_v2.native_address (Cassandra 4)' """ - The IP address or hostname of the node. + + broadcast_rpc_port = None + """ + The broadcast rpc port of the node, *if available*: + + 'system.local.rpc_port' or 'system.peers.native_transport_port' (DSE 6+) + 'system.local.rpc_port' or 'system.peers_v2.native_port' (Cassandra 4) + """ + + listen_address = None + """ + listen address configured for the node, *if available*: + + 'system.local.listen_address' + + This is only available in the ``system.local`` table for newer versions of Cassandra. It is also not + queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``. 
Usually the same as ``broadcast_address`` + unless configured differently in cassandra.yaml. + """ + + listen_port = None + """ + listen port configured for the node, *if available*: + + 'system.local.listen_port' + + This is only available in the ``system.local`` table for newer versions of Cassandra. It is also not + queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``. """ conviction_policy = None @@ -64,6 +126,37 @@ class Host(object): up or down. """ + release_version = None + """ + release_version as queried from the control connection system tables + """ + + host_id = None + """ + The unique identifier of the cassandra node + """ + + dse_version = None + """ + dse_version as queried from the control connection system tables. Only populated when connecting to + DSE with this property available. Not queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``. + """ + + dse_workload = None + """ + DSE workload queried from the control connection system tables. Only populated when connecting to + DSE with this property available. Not queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``. + This is a legacy attribute that does not portray multiple workloads in a uniform fashion. + See also :attr:`~.Host.dse_workloads`. + """ + + dse_workloads = None + """ + DSE workloads set, queried from the control connection system tables. Only populated when connecting to + DSE with this property available (added in DSE 5.1). + Not queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``. + """ + _datacenter = None _rack = None _reconnection_handler = None @@ -71,17 +164,26 @@ class Host(object): _currently_handling_node_up = False - def __init__(self, inet_address, conviction_policy_factory, datacenter=None, rack=None): - if inet_address is None: - raise ValueError("inet_address may not be None") + def __init__(self, endpoint, conviction_policy_factory, datacenter=None, rack=None, host_id=None): + if endpoint is None: + raise ValueError("endpoint may not be None") if conviction_policy_factory is None: raise ValueError("conviction_policy_factory may not be None") - self.address = inet_address + self.endpoint = endpoint if isinstance(endpoint, EndPoint) else DefaultEndPoint(endpoint) self.conviction_policy = conviction_policy_factory(self) + self.host_id = host_id self.set_location_info(datacenter, rack) self.lock = RLock() + @property + def address(self): + """ + The IP address of the endpoint. This is the RPC address the driver uses when connecting to the node. + """ + # backward compatibility + return self.endpoint.address + @property def datacenter(self): """ The datacenter the node is in. 
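        For example, assuming a connected :class:`~.Cluster` instance named
        ``cluster`` (illustrative only), the location and endpoint information
        of each known node can be inspected with:

        .. code-block:: python

            for host in cluster.metadata.all_hosts():
                print(host.endpoint, host.address, host.datacenter, host.rack)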
""" @@ -103,7 +205,7 @@ def set_location_info(self, datacenter, rack): def set_up(self): if not self.is_up: - log.debug("Host %s is now marked up", self.address) + log.debug("Host %s is now marked up", self.endpoint) self.conviction_policy.reset() self.is_up = True @@ -127,20 +229,23 @@ def get_and_set_reconnection_handler(self, new_handler): return old def __eq__(self, other): - return self.address == other.address + if isinstance(other, Host): + return self.endpoint == other.endpoint + else: # TODO Backward compatibility, remove next major + return self.endpoint.address == other def __hash__(self): - return hash(self.address) + return hash(self.endpoint) def __lt__(self, other): - return self.address < other.address + return self.endpoint < other.endpoint def __str__(self): - return str(self.address) + return str(self.endpoint) def __repr__(self): dc = (" %s" % (self._datacenter,)) if self._datacenter else "" - return "<%s: %s%s>" % (self.__class__.__name__, self.address, dc) + return "<%s: %s%s>" % (self.__class__.__name__, self.endpoint, dc) class _ReconnectionHandler(object): @@ -272,17 +377,25 @@ class HostConnection(object): host = None host_distance = None is_shutdown = False + shutdown_on_error = False _session = None _connection = None _lock = None + _keyspace = None def __init__(self, host, host_distance, session): self.host = host self.host_distance = host_distance self._session = weakref.proxy(session) self._lock = Lock() + # this is used in conjunction with the connection streams. Not using the connection lock because the connection can be replaced in the lifetime of the pool. + self._stream_available_condition = Condition(self._lock) self._is_replacing = False + # Contains connections which shouldn't be used anymore + # and are waiting until all requests time out or complete + # so that we can dispose of them. 
+ self._trash = set() if host_distance == HostDistance.IGNORED: log.debug("Not opening connection to ignored host %s", self.host) @@ -292,12 +405,13 @@ def __init__(self, host, host_distance, session): return log.debug("Initializing connection for host %s", self.host) - self._connection = session.cluster.connection_factory(host.address) - if session.keyspace: - self._connection.set_keyspace_blocking(session.keyspace) + self._connection = session.cluster.connection_factory(host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released) + self._keyspace = session.keyspace + if self._keyspace: + self._connection.set_keyspace_blocking(self._keyspace) log.debug("Finished initializing connection for host %s", self.host) - def borrow_connection(self, timeout): + def _get_connection(self): if self.is_shutdown: raise ConnectionException( "Pool for %s is shutdown" % (self.host,), self.host) @@ -305,24 +419,62 @@ def borrow_connection(self, timeout): conn = self._connection if not conn: raise NoConnectionsAvailable() + return conn - with conn.lock: - if conn.in_flight < conn.max_request_id: - conn.in_flight += 1 - return conn, conn.get_request_id() + def borrow_connection(self, timeout): + conn = self._get_connection() + if conn.orphaned_threshold_reached: + with self._lock: + if not self._is_replacing: + self._is_replacing = True + self._session.submit(self._replace, conn) + log.debug( + "Connection to host %s reached orphaned stream limit, replacing...", + self.host + ) + + start = time.time() + remaining = timeout + while True: + with conn.lock: + if not (conn.orphaned_threshold_reached and conn.is_closed) and conn.in_flight < conn.max_request_id: + conn.in_flight += 1 + return conn, conn.get_request_id() + if timeout is not None: + remaining = timeout - time.time() + start + if remaining < 0: + break + with self._stream_available_condition: + if conn.orphaned_threshold_reached and conn.is_closed: + conn = self._get_connection() + else: + self._stream_available_condition.wait(remaining) raise NoConnectionsAvailable("All request IDs are currently in use") - def return_connection(self, connection): - with connection.lock: - connection.in_flight -= 1 - - if (connection.is_defunct or connection.is_closed) and not connection.signaled_error: - log.debug("Defunct or closed connection (%s) returned to pool, potentially " - "marking host %s as down", id(connection), self.host) - is_down = self._session.cluster.signal_connection_failure( - self.host, connection.last_error, is_host_addition=False) - connection.signaled_error = True + def return_connection(self, connection, stream_was_orphaned=False): + if not stream_was_orphaned: + with connection.lock: + connection.in_flight -= 1 + with self._stream_available_condition: + self._stream_available_condition.notify() + + if connection.is_defunct or connection.is_closed: + if connection.signaled_error and not self.shutdown_on_error: + return + + is_down = False + if not connection.signaled_error: + log.debug("Defunct or closed connection (%s) returned to pool, potentially " + "marking host %s as down", id(connection), self.host) + is_down = self._session.cluster.signal_connection_failure( + self.host, connection.last_error, is_host_addition=False) + connection.signaled_error = True + + if self.shutdown_on_error and not is_down: + is_down = True + self._session.cluster.on_down(self.host, is_host_addition=False) + if is_down: self.shutdown() else: @@ -332,15 +484,49 @@ def return_connection(self, connection): return self._is_replacing = True 
self._session.submit(self._replace, connection) + else: + if connection in self._trash: + with connection.lock: + if connection.in_flight == len(connection.orphaned_request_ids): + with self._lock: + if connection in self._trash: + self._trash.remove(connection) + log.debug("Closing trashed connection (%s) to %s", id(connection), self.host) + connection.close() + return + + def on_orphaned_stream_released(self): + """ + Called when a response for an orphaned stream (timed out on the client + side) was received. + """ + with self._stream_available_condition: + self._stream_available_condition.notify() def _replace(self, connection): - log.debug("Replacing connection (%s) to %s", id(connection), self.host) - conn = self._session.cluster.connection_factory(self.host.address) - if self._session.keyspace: - conn.set_keyspace_blocking(self._session.keyspace) - self._connection = conn with self._lock: - self._is_replacing = False + if self.is_shutdown: + return + + log.debug("Replacing connection (%s) to %s", id(connection), self.host) + try: + conn = self._session.cluster.connection_factory(self.host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released) + if self._keyspace: + conn.set_keyspace_blocking(self._keyspace) + self._connection = conn + except Exception: + log.warning("Failed reconnecting %s. Retrying." % (self.host.endpoint,)) + self._session.submit(self._replace, connection) + else: + with connection.lock: + with self._lock: + if connection.orphaned_threshold_reached: + if connection.in_flight == len(connection.orphaned_request_ids): + connection.close() + else: + self._trash.add(connection) + self._is_replacing = False + self._stream_available_condition.notify() def shutdown(self): with self._lock: @@ -348,9 +534,21 @@ def shutdown(self): return else: self.is_shutdown = True + self._stream_available_condition.notify_all() if self._connection: self._connection.close() + self._connection = None + + trash_conns = None + with self._lock: + if self._trash: + trash_conns = self._trash + self._trash = set() + + if trash_conns is not None: + for conn in self._trash: + conn.close() def _set_keyspace_for_all_conns(self, keyspace, callback): if self.is_shutdown or not self._connection: @@ -361,6 +559,7 @@ def connection_finished_setting_keyspace(conn, error): errors = [] if not error else [error] callback(self, errors) + self._keyspace = keyspace self._connection.set_keyspace_async(keyspace, connection_finished_setting_keyspace) def get_connections(self): @@ -371,7 +570,9 @@ def get_state(self): connection = self._connection open_count = 1 if connection and not (connection.is_closed or connection.is_defunct) else 0 in_flights = [connection.in_flight] if connection else [] - return {'shutdown': self.is_shutdown, 'open_count': open_count, 'in_flights': in_flights} + orphan_requests = [connection.orphaned_request_ids] if connection else [] + return {'shutdown': self.is_shutdown, 'open_count': open_count, \ + 'in_flights': in_flights, 'orphan_requests': orphan_requests} @property def open_count(self): @@ -394,6 +595,7 @@ class HostConnectionPool(object): open_count = 0 _scheduled_for_creation = 0 _next_trash_allowed_at = 0 + _keyspace = None def __init__(self, host, host_distance, session): self.host = host @@ -405,12 +607,13 @@ def __init__(self, host, host_distance, session): log.debug("Initializing new connection pool for host %s", self.host) core_conns = session.cluster.get_core_connections_per_host(host_distance) - self._connections = 
[session.cluster.connection_factory(host.address) + self._connections = [session.cluster.connection_factory(host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released) for i in range(core_conns)] - if session.keyspace: + self._keyspace = session.keyspace + if self._keyspace: for conn in self._connections: - conn.set_keyspace_blocking(session.keyspace) + conn.set_keyspace_blocking(self._keyspace) self._trash = set() self._next_trash_allowed_at = time.time() @@ -465,7 +668,7 @@ def borrow_connection(self, timeout): # wait_for_conn will increment in_flight on the conn least_busy, request_id = self._wait_for_conn(timeout) - # if we have too many requests on this connection but we still + # if we have too many requests on this connection, but we still # have space to open a new connection against this host, go ahead # and schedule the creation of a new connection if least_busy.in_flight >= max_reqs and len(self._connections) < max_conns: @@ -499,23 +702,23 @@ def _add_conn_if_under_max(self): max_conns = self._session.cluster.get_max_connections_per_host(self.host_distance) with self._lock: if self.is_shutdown: - return False + return True if self.open_count >= max_conns: - return False + return True self.open_count += 1 log.debug("Going to open new connection to host %s", self.host) try: - conn = self._session.cluster.connection_factory(self.host.address) - if self._session.keyspace: + conn = self._session.cluster.connection_factory(self.host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released) + if self._keyspace: conn.set_keyspace_blocking(self._session.keyspace) self._next_trash_allowed_at = time.time() + _MIN_TRASH_INTERVAL with self._lock: new_connections = self._connections[:] + [conn] self._connections = new_connections - log.debug("Added new connection (%s) to pool for host %s, signaling availablility", + log.debug("Added new connection (%s) to pool for host %s, signaling availability", id(conn), self.host) self._signal_available_conn() return True @@ -568,9 +771,10 @@ def _wait_for_conn(self, timeout): raise NoConnectionsAvailable() - def return_connection(self, connection): + def return_connection(self, connection, stream_was_orphaned=False): with connection.lock: - connection.in_flight -= 1 + if not stream_was_orphaned: + connection.in_flight -= 1 in_flight = connection.in_flight if connection.is_defunct or connection.is_closed: @@ -606,6 +810,13 @@ def return_connection(self, connection): else: self._signal_available_conn() + def on_orphaned_stream_released(self): + """ + Called when a response for an orphaned stream (timed out on the client + side) was received. + """ + self._signal_available_conn() + def _maybe_trash_connection(self, connection): core_conns = self._session.cluster.get_core_connections_per_host(self.host_distance) did_trash = False @@ -646,17 +857,22 @@ def _replace(self, connection): if should_replace: log.debug("Replacing connection (%s) to %s", id(connection), self.host) - - def close_and_replace(): - connection.close() - self._add_conn_if_under_max() - - self._session.submit(close_and_replace) + connection.close() + self._session.submit(self._retrying_replace) else: - # just close it log.debug("Closing connection (%s) to %s", id(connection), self.host) connection.close() + def _retrying_replace(self): + replaced = False + try: + replaced = self._add_conn_if_under_max() + except Exception: + log.exception("Failed replacing connection to %s", self.host) + if not replaced: + log.debug("Failed replacing connection to %s. 
Retrying.", self.host) + self._session.submit(self._retrying_replace) + def shutdown(self): with self._lock: if self.is_shutdown: @@ -705,6 +921,7 @@ def connection_finished_setting_keyspace(conn, error): if not remaining_callbacks: callback(self, errors) + self._keyspace = keyspace for conn in self._connections: conn.set_keyspace_async(keyspace, connection_finished_setting_keyspace) @@ -713,4 +930,6 @@ def get_connections(self): def get_state(self): in_flights = [c.in_flight for c in self._connections] - return {'shutdown': self.is_shutdown, 'open_count': self.open_count, 'in_flights': in_flights} + orphan_requests = [c.orphaned_request_ids for c in self._connections] + return {'shutdown': self.is_shutdown, 'open_count': self.open_count, \ + 'in_flights': in_flights, 'orphan_requests': orphan_requests} diff --git a/cassandra/protocol.py b/cassandra/protocol.py index d2b9467348..510aea44a8 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -18,19 +20,15 @@ import socket from uuid import UUID -import six -from six.moves import range import io -from cassandra import type_codes +from cassandra import ProtocolVersion +from cassandra import type_codes, DriverException from cassandra import (Unavailable, WriteTimeout, ReadTimeout, WriteFailure, ReadFailure, FunctionFailure, AlreadyExists, InvalidRequest, Unauthorized, UnsupportedOperation, UserFunctionDescriptor, UserAggregateDescriptor, SchemaTargetType) -from cassandra.marshal import (int32_pack, int32_unpack, uint16_pack, uint16_unpack, - int8_pack, int8_unpack, uint64_pack, header_pack, - v3_header_pack) from cassandra.cqltypes import (AsciiType, BytesType, BooleanType, CounterColumnType, DateType, DecimalType, DoubleType, FloatType, Int32Type, @@ -38,8 +36,12 @@ LongType, MapType, SetType, TimeUUIDType, UTF8Type, VarcharType, UUIDType, UserType, TupleType, lookup_casstype, SimpleDateType, - TimeType, ByteType, ShortType) -from cassandra.policies import WriteType + TimeType, ByteType, ShortType, DurationType) +from cassandra.marshal import (int32_pack, int32_unpack, uint16_pack, uint16_unpack, + uint8_pack, int8_unpack, uint64_pack, header_pack, + v3_header_pack, uint32_pack, uint32_le_unpack, uint32_le_pack) +from cassandra.policies import ColDesc +from cassandra import WriteType from cassandra.cython_deps import HAVE_CYTHON, HAVE_NUMPY from cassandra import util @@ -55,9 +57,6 @@ class InternalError(Exception): ColumnMetadata = namedtuple("ColumnMetadata", ['keyspace_name', 'table_name', 'name', 'type']) -MIN_SUPPORTED_VERSION = 1 -MAX_SUPPORTED_VERSION = 4 - HEADER_DIRECTION_TO_CLIENT = 0x80 HEADER_DIRECTION_MASK = 0x80 @@ -65,6 +64,8 @@ class 
InternalError(Exception): TRACING_FLAG = 0x02 CUSTOM_PAYLOAD_FLAG = 0x04 WARNING_FLAG = 0x08 +USE_BETA_FLAG = 0x10 +USE_BETA_MASK = ~USE_BETA_FLAG _message_types_by_opcode = {} @@ -85,8 +86,7 @@ def __init__(cls, name, bases, dct): register_class(cls) -@six.add_metaclass(_RegisterMessageType) -class _MessageType(object): +class _MessageType(object, metaclass=_RegisterMessageType): tracing = False custom_payload = None @@ -126,24 +126,24 @@ def __init__(self, code, message, info): self.info = info @classmethod - def recv_body(cls, f, protocol_version, user_type_map): + def recv_body(cls, f, protocol_version, *args): code = read_int(f) msg = read_string(f) subcls = error_classes.get(code, cls) - extra_info = subcls.recv_error_info(f) + extra_info = subcls.recv_error_info(f, protocol_version) return subcls(code=code, message=msg, info=extra_info) def summary_msg(self): - msg = 'code=%04x [%s] message="%s"' \ + msg = 'Error from server: code=%04x [%s] message="%s"' \ % (self.code, self.summary, self.message) return msg def __str__(self): - return '' % self.summary_msg() + return '<%s>' % self.summary_msg() __repr__ = __str__ @staticmethod - def recv_error_info(f): + def recv_error_info(f, protocol_version): pass def to_exception(self): @@ -156,8 +156,7 @@ def __init__(cls, name, bases, dct): error_classes[cls.error_code] = cls -@six.add_metaclass(ErrorMessageSubclass) -class ErrorMessageSub(ErrorMessage): +class ErrorMessageSub(ErrorMessage, metaclass=ErrorMessageSubclass): error_code = None @@ -178,6 +177,10 @@ class ProtocolException(ErrorMessageSub): summary = 'Protocol error' error_code = 0x000A + @property + def is_beta_protocol_error(self): + return 'USE_BETA flag is unset' in str(self) + class BadCredentials(ErrorMessageSub): summary = 'Bad credentials' @@ -189,7 +192,7 @@ class UnavailableErrorMessage(RequestExecutionException): error_code = 0x1000 @staticmethod - def recv_error_info(f): + def recv_error_info(f, protocol_version): return { 'consistency': read_consistency_level(f), 'required_replicas': read_int(f), @@ -220,7 +223,7 @@ class WriteTimeoutErrorMessage(RequestExecutionException): error_code = 0x1100 @staticmethod - def recv_error_info(f): + def recv_error_info(f, protocol_version): return { 'consistency': read_consistency_level(f), 'received_responses': read_int(f), @@ -237,7 +240,7 @@ class ReadTimeoutErrorMessage(RequestExecutionException): error_code = 0x1200 @staticmethod - def recv_error_info(f): + def recv_error_info(f, protocol_version): return { 'consistency': read_consistency_level(f), 'received_responses': read_int(f), @@ -254,13 +257,27 @@ class ReadFailureMessage(RequestExecutionException): error_code = 0x1300 @staticmethod - def recv_error_info(f): + def recv_error_info(f, protocol_version): + consistency = read_consistency_level(f) + received_responses = read_int(f) + required_responses = read_int(f) + + if ProtocolVersion.uses_error_code_map(protocol_version): + error_code_map = read_error_code_map(f) + failures = len(error_code_map) + else: + error_code_map = None + failures = read_int(f) + + data_retrieved = bool(read_byte(f)) + return { - 'consistency': read_consistency_level(f), - 'received_responses': read_int(f), - 'required_responses': read_int(f), - 'failures': read_int(f), - 'data_retrieved': bool(read_byte(f)), + 'consistency': consistency, + 'received_responses': received_responses, + 'required_responses': required_responses, + 'failures': failures, + 'error_code_map': error_code_map, + 'data_retrieved': data_retrieved } def to_exception(self): @@ 
-272,7 +289,7 @@ class FunctionFailureMessage(RequestExecutionException): error_code = 0x1400 @staticmethod - def recv_error_info(f): + def recv_error_info(f, protocol_version): return { 'keyspace': read_string(f), 'function': read_string(f), @@ -288,19 +305,38 @@ class WriteFailureMessage(RequestExecutionException): error_code = 0x1500 @staticmethod - def recv_error_info(f): + def recv_error_info(f, protocol_version): + consistency = read_consistency_level(f) + received_responses = read_int(f) + required_responses = read_int(f) + + if ProtocolVersion.uses_error_code_map(protocol_version): + error_code_map = read_error_code_map(f) + failures = len(error_code_map) + else: + error_code_map = None + failures = read_int(f) + + write_type = WriteType.name_to_value[read_string(f)] + return { - 'consistency': read_consistency_level(f), - 'received_responses': read_int(f), - 'required_responses': read_int(f), - 'failures': read_int(f), - 'write_type': WriteType.name_to_value[read_string(f)], + 'consistency': consistency, + 'received_responses': received_responses, + 'required_responses': required_responses, + 'failures': failures, + 'error_code_map': error_code_map, + 'write_type': write_type } def to_exception(self): return WriteFailure(self.summary_msg(), **self.info) +class CDCWriteException(RequestExecutionException): + summary = 'Failed to execute write due to CDC space exhaustion.' + error_code = 0x1600 + + class SyntaxException(RequestValidationException): summary = 'Syntax error in CQL query' error_code = 0x2000 @@ -332,7 +368,7 @@ class PreparedQueryNotFound(RequestValidationException): error_code = 0x2500 @staticmethod - def recv_error_info(f): + def recv_error_info(f, protocol_version): # return the query ID return read_binary_string(f) @@ -342,7 +378,7 @@ class AlreadyExistsException(ConfigurationException): error_code = 0x2400 @staticmethod - def recv_error_info(f): + def recv_error_info(f, protocol_version): return { 'keyspace': read_string(f), 'table': read_string(f), @@ -352,6 +388,11 @@ def to_exception(self): return AlreadyExists(**self.info) +class ClientWriteError(RequestExecutionException): + summary = 'Client write failure.' 
+ error_code = 0x8000 + + class StartupMessage(_MessageType): opcode = 0x01 name = 'STARTUP' @@ -359,6 +400,7 @@ class StartupMessage(_MessageType): KNOWN_OPTION_KEYS = set(( 'CQL_VERSION', 'COMPRESSION', + 'NO_COMPACT' )) def __init__(self, cqlversion, options): @@ -376,7 +418,7 @@ class ReadyMessage(_MessageType): name = 'READY' @classmethod - def recv_body(cls, f, protocol_version, user_type_map): + def recv_body(cls, *args): return cls() @@ -388,7 +430,7 @@ def __init__(self, authenticator): self.authenticator = authenticator @classmethod - def recv_body(cls, f, protocol_version, user_type_map): + def recv_body(cls, f, *args): authname = read_string(f) return cls(authenticator=authname) @@ -420,7 +462,7 @@ def __init__(self, challenge): self.challenge = challenge @classmethod - def recv_body(cls, f, protocol_version, user_type_map): + def recv_body(cls, f, *args): return cls(read_binary_longstring(f)) @@ -443,7 +485,7 @@ def __init__(self, token): self.token = token @classmethod - def recv_body(cls, f, protocol_version, user_type_map): + def recv_body(cls, f, *args): return cls(read_longstring(f)) @@ -464,7 +506,7 @@ def __init__(self, cql_versions, options): self.options = options @classmethod - def recv_body(cls, f, protocol_version, user_type_map): + def recv_body(cls, f, *args): options = read_stringmultimap(f) cql_versions = options.pop('CQL_VERSION') return cls(cql_versions=cql_versions, options=options) @@ -472,32 +514,38 @@ def recv_body(cls, f, protocol_version, user_type_map): # used for QueryMessage and ExecuteMessage _VALUES_FLAG = 0x01 -_SKIP_METADATA_FLAG = 0x01 +_SKIP_METADATA_FLAG = 0x02 _PAGE_SIZE_FLAG = 0x04 _WITH_PAGING_STATE_FLAG = 0x08 _WITH_SERIAL_CONSISTENCY_FLAG = 0x10 -_PROTOCOL_TIMESTAMP = 0x20 +_PROTOCOL_TIMESTAMP_FLAG = 0x20 +_NAMES_FOR_VALUES_FLAG = 0x40 # not used here +_WITH_KEYSPACE_FLAG = 0x80 +_PREPARED_WITH_KEYSPACE_FLAG = 0x01 +_PAGE_SIZE_BYTES_FLAG = 0x40000000 +_PAGING_OPTIONS_FLAG = 0x80000000 -class QueryMessage(_MessageType): - opcode = 0x07 - name = 'QUERY' +class _QueryMessage(_MessageType): - def __init__(self, query, consistency_level, serial_consistency_level=None, - fetch_size=None, paging_state=None, timestamp=None): - self.query = query + def __init__(self, query_params, consistency_level, + serial_consistency_level=None, fetch_size=None, + paging_state=None, timestamp=None, skip_meta=False, + continuous_paging_options=None, keyspace=None): + self.query_params = query_params self.consistency_level = consistency_level self.serial_consistency_level = serial_consistency_level self.fetch_size = fetch_size self.paging_state = paging_state self.timestamp = timestamp - self._query_params = None # only used internally. May be set to a list of native-encoded values to have them sent with the request. + self.skip_meta = skip_meta + self.continuous_paging_options = continuous_paging_options + self.keyspace = keyspace - def send_body(self, f, protocol_version): - write_longstring(f, self.query) + def _write_query_params(self, f, protocol_version): write_consistency_level(f, self.consistency_level) flags = 0x00 - if self._query_params is not None: + if self.query_params is not None: flags |= _VALUES_FLAG # also v2+, but we're only setting params internally right now if self.serial_consistency_level: @@ -526,15 +574,33 @@ def send_body(self, f, protocol_version): "2 or higher. 
Consider setting Cluster.protocol_version to 2.") if self.timestamp is not None: - flags |= _PROTOCOL_TIMESTAMP + flags |= _PROTOCOL_TIMESTAMP_FLAG - write_byte(f, flags) + if self.continuous_paging_options: + if ProtocolVersion.has_continuous_paging_support(protocol_version): + flags |= _PAGING_OPTIONS_FLAG + else: + raise UnsupportedOperation( + "Continuous paging may only be used with protocol version " + "ProtocolVersion.DSE_V1 or higher. Consider setting Cluster.protocol_version to ProtocolVersion.DSE_V1.") - if self._query_params is not None: - write_short(f, len(self._query_params)) - for param in self._query_params: - write_value(f, param) + if self.keyspace is not None: + if ProtocolVersion.uses_keyspace_flag(protocol_version): + flags |= _WITH_KEYSPACE_FLAG + else: + raise UnsupportedOperation( + "Keyspaces may only be set on queries with protocol version " + "5 or DSE_V2 or higher. Consider setting Cluster.protocol_version.") + + if ProtocolVersion.uses_int_query_flags(protocol_version): + write_uint(f, flags) + else: + write_byte(f, flags) + if self.query_params is not None: + write_short(f, len(self.query_params)) + for param in self.query_params: + write_value(f, param) if self.fetch_size: write_int(f, self.fetch_size) if self.paging_state: @@ -543,6 +609,70 @@ def send_body(self, f, protocol_version): write_consistency_level(f, self.serial_consistency_level) if self.timestamp is not None: write_long(f, self.timestamp) + if self.keyspace is not None: + write_string(f, self.keyspace) + if self.continuous_paging_options: + self._write_paging_options(f, self.continuous_paging_options, protocol_version) + + def _write_paging_options(self, f, paging_options, protocol_version): + write_int(f, paging_options.max_pages) + write_int(f, paging_options.max_pages_per_second) + if ProtocolVersion.has_continuous_paging_next_pages(protocol_version): + write_int(f, paging_options.max_queue_size) + + +class QueryMessage(_QueryMessage): + opcode = 0x07 + name = 'QUERY' + + def __init__(self, query, consistency_level, serial_consistency_level=None, + fetch_size=None, paging_state=None, timestamp=None, continuous_paging_options=None, keyspace=None): + self.query = query + super(QueryMessage, self).__init__(None, consistency_level, serial_consistency_level, fetch_size, + paging_state, timestamp, False, continuous_paging_options, keyspace) + + def send_body(self, f, protocol_version): + write_longstring(f, self.query) + self._write_query_params(f, protocol_version) + + +class ExecuteMessage(_QueryMessage): + opcode = 0x0A + name = 'EXECUTE' + + def __init__(self, query_id, query_params, consistency_level, + serial_consistency_level=None, fetch_size=None, + paging_state=None, timestamp=None, skip_meta=False, + continuous_paging_options=None, result_metadata_id=None): + self.query_id = query_id + self.result_metadata_id = result_metadata_id + super(ExecuteMessage, self).__init__(query_params, consistency_level, serial_consistency_level, fetch_size, + paging_state, timestamp, skip_meta, continuous_paging_options) + + def _write_query_params(self, f, protocol_version): + if protocol_version == 1: + if self.serial_consistency_level: + raise UnsupportedOperation( + "Serial consistency levels require the use of protocol version " + "2 or higher. Consider setting Cluster.protocol_version to 2 " + "to support serial consistency levels.") + if self.fetch_size or self.paging_state: + raise UnsupportedOperation( + "Automatic query paging may only be used with protocol version " + "2 or higher. 
Consider setting Cluster.protocol_version to 2.") + write_short(f, len(self.query_params)) + for param in self.query_params: + write_value(f, param) + write_consistency_level(f, self.consistency_level) + else: + super(ExecuteMessage, self)._write_query_params(f, protocol_version) + + def send_body(self, f, protocol_version): + write_string(f, self.query_id) + if ProtocolVersion.uses_prepared_metadata(protocol_version): + write_string(f, self.result_metadata_id) + self._write_query_params(f, protocol_version) + CUSTOM_TYPE = object() @@ -552,6 +682,7 @@ def send_body(self, f, protocol_version): RESULT_KIND_PREPARED = 0x0004 RESULT_KIND_SCHEMA_CHANGE = 0x0005 + class ResultMessage(_MessageType): opcode = 0x08 name = 'RESULT' @@ -566,61 +697,107 @@ class ResultMessage(_MessageType): _FLAGS_GLOBAL_TABLES_SPEC = 0x0001 _HAS_MORE_PAGES_FLAG = 0x0002 _NO_METADATA_FLAG = 0x0004 + _CONTINUOUS_PAGING_FLAG = 0x40000000 + _CONTINUOUS_PAGING_LAST_FLAG = 0x80000000 + _METADATA_ID_FLAG = 0x0008 - def __init__(self, kind, results, paging_state=None): + kind = None + + # These are all the things a result message might contain. They are populated according to 'kind' + column_names = None + column_types = None + parsed_rows = None + paging_state = None + continuous_paging_seq = None + continuous_paging_last = None + new_keyspace = None + column_metadata = None + query_id = None + bind_metadata = None + pk_indexes = None + schema_change_event = None + + def __init__(self, kind): self.kind = kind - self.results = results - self.paging_state = paging_state - @classmethod - def recv_body(cls, f, protocol_version, user_type_map): - kind = read_int(f) - paging_state = None - if kind == RESULT_KIND_VOID: - results = None - elif kind == RESULT_KIND_ROWS: - paging_state, results = cls.recv_results_rows( - f, protocol_version, user_type_map) - elif kind == RESULT_KIND_SET_KEYSPACE: - ksname = read_string(f) - results = ksname - elif kind == RESULT_KIND_PREPARED: - results = cls.recv_results_prepared(f, protocol_version, user_type_map) - elif kind == RESULT_KIND_SCHEMA_CHANGE: - results = cls.recv_results_schema_change(f, protocol_version) + def recv(self, f, protocol_version, user_type_map, result_metadata, column_encryption_policy): + if self.kind == RESULT_KIND_VOID: + return + elif self.kind == RESULT_KIND_ROWS: + self.recv_results_rows(f, protocol_version, user_type_map, result_metadata, column_encryption_policy) + elif self.kind == RESULT_KIND_SET_KEYSPACE: + self.new_keyspace = read_string(f) + elif self.kind == RESULT_KIND_PREPARED: + self.recv_results_prepared(f, protocol_version, user_type_map) + elif self.kind == RESULT_KIND_SCHEMA_CHANGE: + self.recv_results_schema_change(f, protocol_version) else: - raise Exception("Unknown RESULT kind: %d" % kind) - return cls(kind, results, paging_state) + raise DriverException("Unknown RESULT kind: %d" % self.kind) @classmethod - def recv_results_rows(cls, f, protocol_version, user_type_map): - paging_state, column_metadata = cls.recv_results_metadata(f, user_type_map) + def recv_body(cls, f, protocol_version, user_type_map, result_metadata, column_encryption_policy): + kind = read_int(f) + msg = cls(kind) + msg.recv(f, protocol_version, user_type_map, result_metadata, column_encryption_policy) + return msg + + def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata, column_encryption_policy): + self.recv_results_metadata(f, user_type_map) + column_metadata = self.column_metadata or result_metadata rowcount = read_int(f) - rows = [cls.recv_row(f, 
len(column_metadata)) for _ in range(rowcount)] - colnames = [c[2] for c in column_metadata] - coltypes = [c[3] for c in column_metadata] - parsed_rows = [ - tuple(ctype.from_binary(val, protocol_version) - for ctype, val in zip(coltypes, row)) - for row in rows] - return (paging_state, (colnames, parsed_rows)) + rows = [self.recv_row(f, len(column_metadata)) for _ in range(rowcount)] + self.column_names = [c[2] for c in column_metadata] + self.column_types = [c[3] for c in column_metadata] + col_descs = [ColDesc(md[0], md[1], md[2]) for md in column_metadata] - @classmethod - def recv_results_prepared(cls, f, protocol_version, user_type_map): - query_id = read_binary_string(f) - column_metadata, pk_indexes = cls.recv_prepared_metadata(f, protocol_version, user_type_map) - return (query_id, column_metadata, pk_indexes) + def decode_val(val, col_md, col_desc): + uses_ce = column_encryption_policy and column_encryption_policy.contains_column(col_desc) + col_type = column_encryption_policy.column_type(col_desc) if uses_ce else col_md[3] + raw_bytes = column_encryption_policy.decrypt(col_desc, val) if uses_ce else val + return col_type.from_binary(raw_bytes, protocol_version) - @classmethod - def recv_results_metadata(cls, f, user_type_map): + def decode_row(row): + return tuple(decode_val(val, col_md, col_desc) for val, col_md, col_desc in zip(row, column_metadata, col_descs)) + + try: + self.parsed_rows = [decode_row(row) for row in rows] + except Exception: + for row in rows: + for val, col_md, col_desc in zip(row, column_metadata, col_descs): + try: + decode_val(val, col_md, col_desc) + except Exception as e: + raise DriverException('Failed decoding result column "%s" of type %s: %s' % (col_md[2], + col_md[3].cql_parameterized_type(), + str(e))) + + def recv_results_prepared(self, f, protocol_version, user_type_map): + self.query_id = read_binary_string(f) + if ProtocolVersion.uses_prepared_metadata(protocol_version): + self.result_metadata_id = read_binary_string(f) + else: + self.result_metadata_id = None + self.recv_prepared_metadata(f, protocol_version, user_type_map) + + def recv_results_metadata(self, f, user_type_map): flags = read_int(f) - glob_tblspec = bool(flags & cls._FLAGS_GLOBAL_TABLES_SPEC) colcount = read_int(f) - if flags & cls._HAS_MORE_PAGES_FLAG: - paging_state = read_binary_longstring(f) - else: - paging_state = None + if flags & self._HAS_MORE_PAGES_FLAG: + self.paging_state = read_binary_longstring(f) + + no_meta = bool(flags & self._NO_METADATA_FLAG) + if no_meta: + return + + if flags & self._CONTINUOUS_PAGING_FLAG: + self.continuous_paging_seq = read_int(f) + self.continuous_paging_last = flags & self._CONTINUOUS_PAGING_LAST_FLAG + + if flags & self._METADATA_ID_FLAG: + self.result_metadata_id = read_binary_string(f) + + glob_tblspec = bool(flags & self._FLAGS_GLOBAL_TABLES_SPEC) if glob_tblspec: ksname = read_string(f) cfname = read_string(f) @@ -633,24 +810,24 @@ def recv_results_metadata(cls, f, user_type_map): colksname = read_string(f) colcfname = read_string(f) colname = read_string(f) - coltype = cls.read_type(f, user_type_map) + coltype = self.read_type(f, user_type_map) column_metadata.append((colksname, colcfname, colname, coltype)) - return paging_state, column_metadata - @classmethod - def recv_prepared_metadata(cls, f, protocol_version, user_type_map): + self.column_metadata = column_metadata + + def recv_prepared_metadata(self, f, protocol_version, user_type_map): flags = read_int(f) - glob_tblspec = bool(flags & cls._FLAGS_GLOBAL_TABLES_SPEC) 
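(Illustration, not part of the patch: the column_encryption_policy consulted by ResultMessage.recv_results_rows above only has to answer four calls — contains_column, column_type, encrypt and decrypt. A rough, hypothetical sketch with a toy XOR cipher and a stand-in ColDesc, neither of which is driver code:)

from collections import namedtuple

# Stand-in for cassandra.policies.ColDesc, used only in this sketch.
DemoColDesc = namedtuple('DemoColDesc', ['ks', 'table', 'col'])

class XorDemoPolicy(object):
    # Toy policy satisfying the interface used above; not a real cipher.
    def __init__(self, col_desc, col_type, key=0x5A):
        self._col_desc, self._col_type, self._key = col_desc, col_type, key

    def contains_column(self, col_desc):
        # consulted per column before decrypting result bytes
        return col_desc == self._col_desc

    def column_type(self, col_desc):
        # type used to (de)serialize the cleartext value for this column
        return self._col_type

    def encrypt(self, col_desc, serialized_bytes):
        return bytes(b ^ self._key for b in serialized_bytes)

    def decrypt(self, col_desc, encrypted_bytes):
        return bytes(b ^ self._key for b in encrypted_bytes)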
colcount = read_int(f) pk_indexes = None if protocol_version >= 4: num_pk_indexes = read_int(f) pk_indexes = [read_short(f) for _ in range(num_pk_indexes)] + glob_tblspec = bool(flags & self._FLAGS_GLOBAL_TABLES_SPEC) if glob_tblspec: ksname = read_string(f) cfname = read_string(f) - column_metadata = [] + bind_metadata = [] for _ in range(colcount): if glob_tblspec: colksname = ksname @@ -659,13 +836,17 @@ def recv_prepared_metadata(cls, f, protocol_version, user_type_map): colksname = read_string(f) colcfname = read_string(f) colname = read_string(f) - coltype = cls.read_type(f, user_type_map) - column_metadata.append(ColumnMetadata(colksname, colcfname, colname, coltype)) - return column_metadata, pk_indexes + coltype = self.read_type(f, user_type_map) + bind_metadata.append(ColumnMetadata(colksname, colcfname, colname, coltype)) - @classmethod - def recv_results_schema_change(cls, f, protocol_version): - return EventMessage.recv_schema_change(f, protocol_version) + if protocol_version >= 2: + self.recv_results_metadata(f, user_type_map) + + self.bind_metadata = bind_metadata + self.pk_indexes = pk_indexes + + def recv_results_schema_change(self, f, protocol_version): + self.schema_change_event = EventMessage.recv_schema_change(f, protocol_version) @classmethod def read_type(cls, f, user_type_map): @@ -710,72 +891,38 @@ class PrepareMessage(_MessageType): opcode = 0x09 name = 'PREPARE' - def __init__(self, query): + def __init__(self, query, keyspace=None): self.query = query + self.keyspace = keyspace def send_body(self, f, protocol_version): write_longstring(f, self.query) + flags = 0x00 -class ExecuteMessage(_MessageType): - opcode = 0x0A - name = 'EXECUTE' - - def __init__(self, query_id, query_params, consistency_level, - serial_consistency_level=None, fetch_size=None, - paging_state=None, timestamp=None): - self.query_id = query_id - self.query_params = query_params - self.consistency_level = consistency_level - self.serial_consistency_level = serial_consistency_level - self.fetch_size = fetch_size - self.paging_state = paging_state - self.timestamp = timestamp - - def send_body(self, f, protocol_version): - write_string(f, self.query_id) - if protocol_version == 1: - if self.serial_consistency_level: - raise UnsupportedOperation( - "Serial consistency levels require the use of protocol version " - "2 or higher. Consider setting Cluster.protocol_version to 2 " - "to support serial consistency levels.") - if self.fetch_size or self.paging_state: + if self.keyspace is not None: + if ProtocolVersion.uses_keyspace_flag(protocol_version): + flags |= _PREPARED_WITH_KEYSPACE_FLAG + else: raise UnsupportedOperation( - "Automatic query paging may only be used with protocol version " - "2 or higher. Consider setting Cluster.protocol_version to 2.") - write_short(f, len(self.query_params)) - for param in self.query_params: - write_value(f, param) - write_consistency_level(f, self.consistency_level) + "Keyspaces may only be set on queries with protocol version " + "5 or DSE_V2 or higher. 
Consider setting Cluster.protocol_version.") + + if ProtocolVersion.uses_prepare_flags(protocol_version): + write_uint(f, flags) else: - write_consistency_level(f, self.consistency_level) - flags = _VALUES_FLAG - if self.serial_consistency_level: - flags |= _WITH_SERIAL_CONSISTENCY_FLAG - if self.fetch_size: - flags |= _PAGE_SIZE_FLAG - if self.paging_state: - flags |= _WITH_PAGING_STATE_FLAG - if self.timestamp is not None: - if protocol_version >= 3: - flags |= _PROTOCOL_TIMESTAMP - else: - raise UnsupportedOperation( - "Protocol-level timestamps may only be used with protocol version " - "3 or higher. Consider setting Cluster.protocol_version to 3.") - write_byte(f, flags) - write_short(f, len(self.query_params)) - for param in self.query_params: - write_value(f, param) - if self.fetch_size: - write_int(f, self.fetch_size) - if self.paging_state: - write_longstring(f, self.paging_state) - if self.serial_consistency_level: - write_consistency_level(f, self.serial_consistency_level) - if self.timestamp is not None: - write_long(f, self.timestamp) + # checks above should prevent this, but just to be safe... + if flags: + raise UnsupportedOperation( + "Attempted to set flags with value {flags:0=#8x} on" + "protocol version {pv}, which doesn't support flags" + "in prepared statements." + "Consider setting Cluster.protocol_version to 5 or DSE_V2." + "".format(flags=flags, pv=protocol_version)) + + if ProtocolVersion.uses_keyspace_flag(protocol_version): + if self.keyspace: + write_string(f, self.keyspace) class BatchMessage(_MessageType): @@ -783,12 +930,14 @@ class BatchMessage(_MessageType): name = 'BATCH' def __init__(self, batch_type, queries, consistency_level, - serial_consistency_level=None, timestamp=None): + serial_consistency_level=None, timestamp=None, + keyspace=None): self.batch_type = batch_type self.queries = queries self.consistency_level = consistency_level self.serial_consistency_level = serial_consistency_level self.timestamp = timestamp + self.keyspace = keyspace def send_body(self, f, protocol_version): write_byte(f, self.batch_type.value) @@ -811,14 +960,29 @@ def send_body(self, f, protocol_version): if self.serial_consistency_level: flags |= _WITH_SERIAL_CONSISTENCY_FLAG if self.timestamp is not None: - flags |= _PROTOCOL_TIMESTAMP - write_byte(f, flags) + flags |= _PROTOCOL_TIMESTAMP_FLAG + if self.keyspace: + if ProtocolVersion.uses_keyspace_flag(protocol_version): + flags |= _WITH_KEYSPACE_FLAG + else: + raise UnsupportedOperation( + "Keyspaces may only be set on queries with protocol version " + "5 or higher. 
Consider setting Cluster.protocol_version to 5.") + + if ProtocolVersion.uses_int_query_flags(protocol_version): + write_int(f, flags) + else: + write_byte(f, flags) if self.serial_consistency_level: write_consistency_level(f, self.serial_consistency_level) if self.timestamp is not None: write_long(f, self.timestamp) + if ProtocolVersion.uses_keyspace_flag(protocol_version): + if self.keyspace is not None: + write_string(f, self.keyspace) + known_event_types = frozenset(( 'TOPOLOGY_CHANGE', @@ -847,7 +1011,7 @@ def __init__(self, event_type, event_args): self.event_args = event_args @classmethod - def recv_body(cls, f, protocol_version, user_type_map): + def recv_body(cls, f, protocol_version, *args): event_type = read_string(f).upper() if event_type in known_event_types: read_method = getattr(cls, 'recv_' + event_type.lower()) @@ -894,6 +1058,34 @@ def recv_schema_change(cls, f, protocol_version): return event +class ReviseRequestMessage(_MessageType): + + class RevisionType(object): + PAGING_CANCEL = 1 + PAGING_BACKPRESSURE = 2 + + opcode = 0xFF + name = 'REVISE_REQUEST' + + def __init__(self, op_type, op_id, next_pages=0): + self.op_type = op_type + self.op_id = op_id + self.next_pages = next_pages + + def send_body(self, f, protocol_version): + write_int(f, self.op_type) + write_int(f, self.op_id) + if self.op_type == ReviseRequestMessage.RevisionType.PAGING_BACKPRESSURE: + if self.next_pages <= 0: + raise UnsupportedOperation("Continuous paging backpressure requires next_pages > 0") + elif not ProtocolVersion.has_continuous_paging_next_pages(protocol_version): + raise UnsupportedOperation( + "Continuous paging backpressure may only be used with protocol version " + "ProtocolVersion.DSE_V2 or higher. Consider setting Cluster.protocol_version to ProtocolVersion.DSE_V2.") + else: + write_int(f, self.next_pages) + + class _ProtocolHandler(object): """ _ProtocolHander handles encoding and decoding messages. @@ -912,8 +1104,11 @@ class _ProtocolHandler(object): result decoding implementations. 
""" + column_encryption_policy = None + """Instance of :class:`cassandra.policies.ColumnEncryptionPolicy` in use by this handler""" + @classmethod - def encode_message(cls, msg, stream_id, protocol_version, compressor): + def encode_message(cls, msg, stream_id, protocol_version, compressor, allow_beta_protocol_version): """ Encodes a message using the specified frame parameters, and compressor @@ -932,13 +1127,18 @@ def encode_message(cls, msg, stream_id, protocol_version, compressor): msg.send_body(body, protocol_version) body = body.getvalue() - if compressor and len(body) > 0: + # With checksumming, the compression is done at the segment frame encoding + if (not ProtocolVersion.has_checksumming_support(protocol_version) + and compressor and len(body) > 0): body = compressor(body) flags |= COMPRESSED_FLAG if msg.tracing: flags |= TRACING_FLAG + if allow_beta_protocol_version: + flags |= USE_BETA_FLAG + buff = io.BytesIO() cls._write_header(buff, protocol_version, flags, stream_id, msg.opcode, len(body)) buff.write(body) @@ -956,7 +1156,7 @@ def _write_header(f, version, flags, stream_id, opcode, length): @classmethod def decode_message(cls, protocol_version, user_type_map, stream_id, flags, opcode, body, - decompressor): + decompressor, result_metadata): """ Decodes a native protocol message body @@ -969,9 +1169,10 @@ def decode_message(cls, protocol_version, user_type_map, stream_id, flags, opcod :param decompressor: optional decompression function to inflate the body :return: a message decoded from the body and frame attributes """ - if flags & COMPRESSED_FLAG: + if (not ProtocolVersion.has_checksumming_support(protocol_version) and + flags & COMPRESSED_FLAG): if decompressor is None: - raise Exception("No de-compressor available for compressed frame!") + raise RuntimeError("No de-compressor available for compressed frame!") body = decompressor(body) flags ^= COMPRESSED_FLAG @@ -994,11 +1195,13 @@ def decode_message(cls, protocol_version, user_type_map, stream_id, flags, opcod else: custom_payload = None + flags &= USE_BETA_MASK # will only be set if we asserted it in connection estabishment + if flags: log.warning("Unknown protocol flags set: %02x. May cause problems.", flags) msg_class = cls.message_types_by_opcode[opcode] - msg = msg_class.recv_body(body, protocol_version, user_type_map) + msg = msg_class.recv_body(body, protocol_version, user_type_map, result_metadata, cls.column_encryption_policy) msg.stream_id = stream_id msg.trace_id = trace_id msg.custom_payload = custom_payload @@ -1010,6 +1213,7 @@ def decode_message(cls, protocol_version, user_type_map, stream_id, flags, opcod return msg + def cython_protocol_handler(colparser): """ Given a column parser to deserialize ResultMessages, return a suitable @@ -1037,7 +1241,7 @@ class FastResultMessage(ResultMessage): """ # type_codes = ResultMessage.type_codes.copy() code_to_type = dict((v, k) for k, v in ResultMessage.type_codes.items()) - recv_results_rows = classmethod(make_recv_results_rows(colparser)) + recv_results_rows = make_recv_results_rows(colparser) class CythonProtocolHandler(_ProtocolHandler): """ @@ -1075,17 +1279,48 @@ def read_byte(f): def write_byte(f, b): - f.write(int8_pack(b)) + f.write(uint8_pack(b)) def read_int(f): return int32_unpack(f.read(4)) +def read_uint_le(f, size=4): + """ + Read a sequence of little endian bytes and return an unsigned integer. 
+ """ + + if size == 4: + value = uint32_le_unpack(f.read(4)) + else: + value = 0 + for i in range(size): + value |= (read_byte(f) & 0xFF) << 8 * i + + return value + + +def write_uint_le(f, i, size=4): + """ + Write an unsigned integer on a sequence of little endian bytes. + """ + if size == 4: + f.write(uint32_le_pack(i)) + else: + for j in range(size): + shift = j * 8 + write_byte(f, i >> shift & 0xFF) + + def write_int(f, i): f.write(int32_pack(i)) +def write_uint(f, i): + f.write(uint32_pack(i)) + + def write_long(f, i): f.write(uint64_pack(i)) @@ -1119,7 +1354,7 @@ def read_binary_string(f): def write_string(f, s): - if isinstance(s, six.text_type): + if isinstance(s, str): s = s.encode('utf8') write_short(f, len(s)) f.write(s) @@ -1136,7 +1371,7 @@ def read_longstring(f): def write_longstring(f, s): - if isinstance(s, six.text_type): + if isinstance(s, str): s = s.encode('utf8') write_int(f, len(s)) f.write(s) @@ -1201,6 +1436,15 @@ def write_stringmultimap(f, strmmap): write_stringlist(f, v) +def read_error_code_map(f): + numpairs = read_int(f) + error_code_map = {} + for _ in range(numpairs): + endpoint = read_inet_addr_only(f) + error_code_map[endpoint] = read_short(f) + return error_code_map + + def read_value(f): size = read_int(f) if size < 0: @@ -1218,17 +1462,22 @@ def write_value(f, v): f.write(v) -def read_inet(f): +def read_inet_addr_only(f): size = read_byte(f) addrbytes = f.read(size) - port = read_int(f) if size == 4: addrfam = socket.AF_INET elif size == 16: addrfam = socket.AF_INET6 else: raise InternalError("bad inet address: %r" % (addrbytes,)) - return (util.inet_ntop(addrfam, addrbytes), port) + return util.inet_ntop(addrfam, addrbytes) + + +def read_inet(f): + addr = read_inet_addr_only(f) + port = read_int(f) + return (addr, port) def write_inet(f, addrtuple): diff --git a/cassandra/query.py b/cassandra/query.py index 01bf726979..40e4d63c9e 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -19,19 +21,19 @@ """ from collections import namedtuple -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone import re import struct import time -import six -from six.moves import range, zip +import warnings from cassandra import ConsistencyLevel, OperationTimedOut from cassandra.util import unix_time_from_uuid1 from cassandra.encoder import Encoder import cassandra.encoder +from cassandra.policies import ColDesc from cassandra.protocol import _UNSET_VALUE -from cassandra.util import OrderedDict, _positional_rename_invalid_identifiers +from cassandra.util import OrderedDict, _sanitize_identifiers import logging log = logging.getLogger(__name__) @@ -75,7 +77,7 @@ def tuple_factory(colnames, rows): >>> session = cluster.connect('mykeyspace') >>> session.row_factory = tuple_factory >>> rows = session.execute("SELECT name, age FROM users LIMIT 1") - >>> print rows[0] + >>> print(rows[0]) ('Bob', 42) .. versionchanged:: 2.0.0 @@ -83,6 +85,39 @@ def tuple_factory(colnames, rows): """ return rows +class PseudoNamedTupleRow(object): + """ + Helper class for pseudo_named_tuple_factory. These objects provide an + __iter__ interface, as well as index- and attribute-based access to values, + but otherwise do not attempt to implement the full namedtuple or iterable + interface. + """ + def __init__(self, ordered_dict): + self._dict = ordered_dict + self._tuple = tuple(ordered_dict.values()) + + def __getattr__(self, name): + return self._dict[name] + + def __getitem__(self, idx): + return self._tuple[idx] + + def __iter__(self): + return iter(self._tuple) + + def __repr__(self): + return '{t}({od})'.format(t=self.__class__.__name__, + od=self._dict) + + +def pseudo_namedtuple_factory(colnames, rows): + """ + Returns each row as a :class:`.PseudoNamedTupleRow`. This is the fallback + factory for cases where :meth:`.named_tuple_factory` fails to create rows. + """ + return [PseudoNamedTupleRow(od) + for od in ordered_dict_factory(colnames, rows)] + def named_tuple_factory(colnames, rows): """ @@ -98,16 +133,16 @@ def named_tuple_factory(colnames, rows): >>> user = rows[0] >>> # you can access field by their name: - >>> print "name: %s, age: %d" % (user.name, user.age) + >>> print("name: %s, age: %d" % (user.name, user.age)) name: Bob, age: 42 >>> # or you can access fields by their position (like a tuple) >>> name, age = user - >>> print "name: %s, age: %d" % (name, age) + >>> print("name: %s, age: %d" % (name, age)) name: Bob, age: 42 >>> name = user[0] >>> age = user[1] - >>> print "name: %s, age: %d" % (name, age) + >>> print("name: %s, age: %d" % (name, age)) name: Bob, age: 42 .. versionchanged:: 2.0.0 @@ -116,14 +151,29 @@ def named_tuple_factory(colnames, rows): clean_column_names = map(_clean_column_name, colnames) try: Row = namedtuple('Row', clean_column_names) + except SyntaxError: + warnings.warn( + "Failed creating namedtuple for a result because there were too " + "many columns. This is due to a Python limitation that affects " + "namedtuple in Python 3.0-3.6 (see issue18896). The row will be " + "created with {substitute_factory_name}, which lacks some namedtuple " + "features and is slower. 
To avoid slower performance accessing " + "values on row objects, Upgrade to Python 3.7, or use a different " + "row factory. (column names: {colnames})".format( + substitute_factory_name=pseudo_namedtuple_factory.__name__, + colnames=colnames + ) + ) + return pseudo_namedtuple_factory(colnames, rows) except Exception: + clean_column_names = list(map(_clean_column_name, colnames)) # create list because py3 map object will be consumed by first attempt log.warning("Failed creating named tuple for results with column names %s (cleaned: %s) " "(see Python 'namedtuple' documentation for details on name rules). " "Results will be returned with positional names. " "Avoid this by choosing different names, using SELECT \"\" AS aliases, " "or specifying a different row_factory on your Session" % (colnames, clean_column_names)) - Row = namedtuple('Row', _positional_rename_invalid_identifiers(clean_column_names)) + Row = namedtuple('Row', _sanitize_identifiers(clean_column_names)) return [Row(*row) for row in rows] @@ -138,7 +188,7 @@ def dict_factory(colnames, rows): >>> session = cluster.connect('mykeyspace') >>> session.row_factory = dict_factory >>> rows = session.execute("SELECT name, age FROM users LIMIT 1") - >>> print rows[0] + >>> print(rows[0]) {u'age': 42, u'name': u'Bob'} .. versionchanged:: 2.0.0 @@ -196,8 +246,7 @@ class Statement(object): keyspace = None """ The string name of the keyspace this query acts on. This is used when - :class:`~.TokenAwarePolicy` is configured for - :attr:`.Cluster.load_balancing_policy` + :class:`~.TokenAwarePolicy` is configured in the profile load balancing policy. It is set implicitly on :class:`.BoundStatement`, and :class:`.BatchStatement`, but must be set explicitly on :class:`.SimpleStatement`. @@ -214,13 +263,21 @@ class Statement(object): .. versionadded:: 2.6.0 """ + is_idempotent = False + """ + Flag indicating whether this statement is safe to run multiple times in speculative execution. 
+ """ + _serial_consistency_level = None _routing_key = None def __init__(self, retry_policy=None, consistency_level=None, routing_key=None, - serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, - custom_payload=None): - self.retry_policy = retry_policy + serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, custom_payload=None, + is_idempotent=False): + if retry_policy and not hasattr(retry_policy, 'on_read_timeout'): # just checking one method to detect positional parameter errors + raise ValueError('retry_policy should implement cassandra.policies.RetryPolicy') + if retry_policy is not None: + self.retry_policy = retry_policy if consistency_level is not None: self.consistency_level = consistency_level self._routing_key = routing_key @@ -232,14 +289,22 @@ def __init__(self, retry_policy=None, consistency_level=None, routing_key=None, self.keyspace = keyspace if custom_payload is not None: self.custom_payload = custom_payload + self.is_idempotent = is_idempotent + + def _key_parts_packed(self, parts): + for p in parts: + l = len(p) + yield struct.pack(">H%dsB" % l, l, p, 0) def _get_routing_key(self): return self._routing_key def _set_routing_key(self, key): if isinstance(key, (list, tuple)): - self._routing_key = b"".join(struct.pack("HsB", len(component), component, 0) - for component in key) + if len(key) == 1: + self._routing_key = key[0] + else: + self._routing_key = b"".join(self._key_parts_packed(key)) else: self._routing_key = key @@ -263,8 +328,8 @@ def _get_serial_consistency_level(self): return self._serial_consistency_level def _set_serial_consistency_level(self, serial_consistency_level): - acceptable = (None, ConsistencyLevel.SERIAL, ConsistencyLevel.LOCAL_SERIAL) - if serial_consistency_level not in acceptable: + if (serial_consistency_level is not None and + not ConsistencyLevel.is_serial(serial_consistency_level)): raise ValueError( "serial_consistency_level must be either ConsistencyLevel.SERIAL " "or ConsistencyLevel.LOCAL_SERIAL") @@ -317,15 +382,18 @@ class SimpleStatement(Statement): A simple, un-prepared query. """ - def __init__(self, query_string, *args, **kwargs): + def __init__(self, query_string, retry_policy=None, consistency_level=None, routing_key=None, + serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, + custom_payload=None, is_idempotent=False): """ `query_string` should be a literal CQL statement with the exception of parameter placeholders that will be filled through the `parameters` argument of :meth:`.Session.execute()`. - All arguments to :class:`Statement` apply to this class as well + See :class:`Statement` attributes for a description of the other parameters. """ - Statement.__init__(self, *args, **kwargs) + Statement.__init__(self, retry_policy, consistency_level, routing_key, + serial_consistency_level, fetch_size, keyspace, custom_payload, is_idempotent) self._query_string = query_string @property @@ -347,38 +415,61 @@ class PreparedStatement(object): A :class:`.PreparedStatement` should be prepared only once. Re-preparing a statement may affect performance (as the operation requires a network roundtrip). + + |prepared_stmt_head|: Do not use ``*`` in prepared statements if you might + change the schema of the table being queried. The driver and server each + maintain a map between metadata for a schema and statements that were + prepared against that schema. When a user changes a schema, e.g. 
by adding + or removing a column, the server invalidates its mappings involving that + schema. However, there is currently no way to propagate that invalidation + to drivers. Thus, after a schema change, the driver will incorrectly + interpret the results of ``SELECT *`` queries prepared before the schema + change. This is currently being addressed in `CASSANDRA-10786 + `_. + + .. |prepared_stmt_head| raw:: html + + A note about * in prepared statements """ - column_metadata = None + column_metadata = None #TODO: make this bind_metadata in next major + retry_policy = None + consistency_level = None + custom_payload = None + fetch_size = FETCH_SIZE_UNSET + keyspace = None # change to prepared_keyspace in major release + protocol_version = None query_id = None query_string = None - keyspace = None # change to prepared_keyspace in major release - + result_metadata = None + result_metadata_id = None + column_encryption_policy = None routing_key_indexes = None _routing_key_index_set = None - - consistency_level = None - serial_consistency_level = None - - protocol_version = None - - fetch_size = FETCH_SIZE_UNSET - - custom_payload = None + serial_consistency_level = None # TODO never used? def __init__(self, column_metadata, query_id, routing_key_indexes, query, - keyspace, protocol_version): + keyspace, protocol_version, result_metadata, result_metadata_id, + column_encryption_policy=None): self.column_metadata = column_metadata self.query_id = query_id self.routing_key_indexes = routing_key_indexes self.query_string = query self.keyspace = keyspace self.protocol_version = protocol_version + self.result_metadata = result_metadata + self.result_metadata_id = result_metadata_id + self.column_encryption_policy = column_encryption_policy + self.is_idempotent = False @classmethod - def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata, query, prepared_keyspace, protocol_version): + def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata, + query, prepared_keyspace, protocol_version, result_metadata, + result_metadata_id, column_encryption_policy=None): if not column_metadata: - return PreparedStatement(column_metadata, query_id, None, query, prepared_keyspace, protocol_version) + return PreparedStatement(column_metadata, query_id, None, + query, prepared_keyspace, protocol_version, result_metadata, + result_metadata_id, column_encryption_policy) if pk_indexes: routing_key_indexes = pk_indexes @@ -403,7 +494,8 @@ def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata, q pass # statement; just leave routing_key_indexes as None return PreparedStatement(column_metadata, query_id, routing_key_indexes, - query, prepared_keyspace, protocol_version) + query, prepared_keyspace, protocol_version, result_metadata, + result_metadata_id, column_encryption_policy) def bind(self, values): """ @@ -441,25 +533,31 @@ class BoundStatement(Statement): The sequence of values that were bound to the prepared statement. """ - def __init__(self, prepared_statement, *args, **kwargs): + def __init__(self, prepared_statement, retry_policy=None, consistency_level=None, routing_key=None, + serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, + custom_payload=None): """ `prepared_statement` should be an instance of :class:`PreparedStatement`. - All arguments to :class:`Statement` apply to this class as well + See :class:`Statement` attributes for a description of the other parameters. 
""" self.prepared_statement = prepared_statement + self.retry_policy = prepared_statement.retry_policy self.consistency_level = prepared_statement.consistency_level self.serial_consistency_level = prepared_statement.serial_consistency_level self.fetch_size = prepared_statement.fetch_size self.custom_payload = prepared_statement.custom_payload + self.is_idempotent = prepared_statement.is_idempotent self.values = [] meta = prepared_statement.column_metadata if meta: self.keyspace = meta[0].keyspace_name - Statement.__init__(self, *args, **kwargs) + Statement.__init__(self, retry_policy, consistency_level, routing_key, + serial_consistency_level, fetch_size, keyspace, custom_payload, + prepared_statement.is_idempotent) def bind(self, values): """ @@ -485,6 +583,7 @@ def bind(self, values): values = () proto_version = self.prepared_statement.protocol_version col_meta = self.prepared_statement.column_metadata + ce_policy = self.prepared_statement.column_encryption_policy # special case for binding dicts if isinstance(values, dict): @@ -531,7 +630,13 @@ def bind(self, values): raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version) else: try: - self.values.append(col_spec.type.serialize(value, proto_version)) + col_desc = ColDesc(col_spec.keyspace_name, col_spec.table_name, col_spec.name) + uses_ce = ce_policy and ce_policy.contains_column(col_desc) + col_type = ce_policy.column_type(col_desc) if uses_ce else col_spec.type + col_bytes = col_type.serialize(value, proto_version) + if uses_ce: + col_bytes = ce_policy.encrypt(col_desc, col_bytes) + self.values.append(col_bytes) except (TypeError, struct.error) as exc: actual_type = type(value) message = ('Received an argument of invalid type for column "%s". ' @@ -565,13 +670,7 @@ def routing_key(self): if len(routing_indexes) == 1: self._routing_key = self.values[routing_indexes[0]] else: - components = [] - for statement_index in routing_indexes: - val = self.values[statement_index] - l = len(val) - components.append(struct.pack(">H%dsB" % l, l, val, 0)) - - self._routing_key = b"".join(components) + self._routing_key = b"".join(self._key_parts_packed(self.values[i] for i in routing_indexes)) return self._routing_key @@ -697,6 +796,19 @@ def __init__(self, batch_type=BatchType.LOGGED, retry_policy=None, Statement.__init__(self, retry_policy=retry_policy, consistency_level=consistency_level, serial_consistency_level=serial_consistency_level, custom_payload=custom_payload) + def clear(self): + """ + This is a convenience method to clear a batch statement for reuse. + + *Note:* it should not be used concurrently with uncompleted execution futures executing the same + ``BatchStatement``. + """ + del self._statements_and_parameters[:] + self.keyspace = None + self.routing_key = None + if self.custom_payload: + self.custom_payload.clear() + def add(self, statement, parameters=None): """ Adds a :class:`.Statement` and optional sequence of parameters @@ -705,25 +817,23 @@ def add(self, statement, parameters=None): Like with other statements, parameters must be a sequence, even if there is only one item. 
""" - if isinstance(statement, six.string_types): + if isinstance(statement, str): if parameters: encoder = Encoder() if self._session is None else self._session.encoder statement = bind_params(statement, parameters, encoder) - self._statements_and_parameters.append((False, statement, ())) + self._add_statement_and_params(False, statement, ()) elif isinstance(statement, PreparedStatement): query_id = statement.query_id bound_statement = statement.bind(() if parameters is None else parameters) self._update_state(bound_statement) - self._statements_and_parameters.append( - (True, query_id, bound_statement.values)) + self._add_statement_and_params(True, query_id, bound_statement.values) elif isinstance(statement, BoundStatement): if parameters: raise ValueError( "Parameters cannot be passed with a BoundStatement " "to BatchStatement.add()") self._update_state(statement) - self._statements_and_parameters.append( - (True, statement.prepared_statement.query_id, statement.values)) + self._add_statement_and_params(True, statement.prepared_statement.query_id, statement.values) else: # it must be a SimpleStatement query_string = statement.query_string @@ -731,17 +841,22 @@ def add(self, statement, parameters=None): encoder = Encoder() if self._session is None else self._session.encoder query_string = bind_params(query_string, parameters, encoder) self._update_state(statement) - self._statements_and_parameters.append((False, query_string, ())) + self._add_statement_and_params(False, query_string, ()) return self def add_all(self, statements, parameters): """ Adds a sequence of :class:`.Statement` objects and a matching sequence - of parameters to the batch. :const:`None` can be used in place of - parameters when no parameters are needed. + of parameters to the batch. Statement and parameter sequences must be of equal length or + one will be truncated. :const:`None` can be used in the parameters position where are needed. """ for statement, value in zip(statements, parameters): - self.add(statement, parameters) + self.add(statement, value) + + def _add_statement_and_params(self, is_prepared, statement, parameters): + if len(self._statements_and_parameters) >= 0xFFFF: + raise ValueError("Batch statement cannot contain more than %d statements." 
% 0xFFFF) + self._statements_and_parameters.append((is_prepared, statement, parameters)) def _maybe_set_routing_attributes(self, statement): if self.routing_key is None: @@ -759,10 +874,13 @@ def _update_state(self, statement): self._maybe_set_routing_attributes(statement) self._update_custom_payload(statement) + def __len__(self): + return len(self._statements_and_parameters) + def __str__(self): consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') return (u'<BatchStatement type=%s, statements=%d, consistency=%s>' % - (self.batch_type, len(self._statements_and_parameters), consistency)) + (self.batch_type, len(self), consistency)) __repr__ = __str__ @@ -783,10 +901,8 @@ def __str__(self): def bind_params(query, params, encoder): - if six.PY2 and isinstance(query, six.text_type): - query = query.encode('utf-8') if isinstance(params, dict): - return query % dict((k, encoder.cql_encode_all_types(v)) for k, v in six.iteritems(params)) + return query % dict((k, encoder.cql_encode_all_types(v)) for k, v in params.items()) else: return query % tuple(encoder.cql_encode_all_types(v) for v in params) @@ -863,7 +979,7 @@ def __init__(self, trace_id, session): self.trace_id = trace_id self._session = session - def populate(self, max_wait=2.0, wait_for_complete=True): + def populate(self, max_wait=2.0, wait_for_complete=True, query_cl=None): """ Retrieves the actual tracing details from Cassandra and populates the attributes of this instance. Because tracing details are stored @@ -874,6 +990,9 @@ def populate(self, max_wait=2.0, wait_for_complete=True): `wait_for_complete=False` bypasses the wait for duration to be populated. This can be used to query events from partial sessions. + + `query_cl` specifies a consistency level to use for polling the trace tables, + if different from the session default. 
""" attempt = 0 start = time.time() @@ -885,9 +1004,11 @@ def populate(self, max_wait=2.0, wait_for_complete=True): log.debug("Attempting to fetch trace info for trace ID: %s", self.trace_id) session_results = self._execute( - self._SELECT_SESSIONS_FORMAT, (self.trace_id,), time_spent, max_wait) + SimpleStatement(self._SELECT_SESSIONS_FORMAT, consistency_level=query_cl), (self.trace_id,), time_spent, max_wait) - is_complete = session_results and session_results[0].duration is not None + # PYTHON-730: There is race condition that the duration mutation is written before started_at the for fast queries + session_row = session_results.one() if session_results else None + is_complete = session_row is not None and session_row.duration is not None and session_row.started_at is not None if not session_results or (wait_for_complete and not is_complete): time.sleep(self._BASE_RETRY_SLEEP * (2 ** attempt)) attempt += 1 @@ -895,9 +1016,8 @@ def populate(self, max_wait=2.0, wait_for_complete=True): if is_complete: log.debug("Fetched trace info for trace ID: %s", self.trace_id) else: - log.debug("Fetching parital trace info for trace ID: %s", self.trace_id) + log.debug("Fetching partial trace info for trace ID: %s", self.trace_id) - session_row = session_results[0] self.request_type = session_row.request self.duration = timedelta(microseconds=session_row.duration) if is_complete else None self.started_at = session_row.started_at @@ -909,7 +1029,7 @@ def populate(self, max_wait=2.0, wait_for_complete=True): log.debug("Attempting to fetch trace events for trace ID: %s", self.trace_id) time_spent = time.time() - start event_results = self._execute( - self._SELECT_EVENTS_FORMAT, (self.trace_id,), time_spent, max_wait) + SimpleStatement(self._SELECT_EVENTS_FORMAT, consistency_level=query_cl), (self.trace_id,), time_spent, max_wait) log.debug("Fetched trace events for trace ID: %s", self.trace_id) self.events = tuple(TraceEvent(r.activity, r.event_id, r.source, r.source_elapsed, r.thread) for r in event_results) @@ -967,7 +1087,7 @@ class TraceEvent(object): def __init__(self, description, timeuuid, source, source_elapsed, thread_name): self.description = description - self.datetime = datetime.utcfromtimestamp(unix_time_from_uuid1(timeuuid)) + self.datetime = datetime.fromtimestamp(unix_time_from_uuid1(timeuuid), tz=timezone.utc) self.source = source if source_elapsed is not None: self.source_elapsed = timedelta(microseconds=source_elapsed) @@ -977,3 +1097,17 @@ def __init__(self, description, timeuuid, source, source_elapsed, thread_name): def __str__(self): return "%s on %s[%s] at %s" % (self.description, self.source, self.thread_name, self.datetime) + + +# TODO remove next major since we can target using the `host` attribute of session.execute +class HostTargetingStatement(object): + """ + Wraps any query statement and attaches a target host, making + it usable in a targeted LBP without modifying the user's statement. + """ + def __init__(self, inner_statement, target_host): + self.__class__ = type(inner_statement.__class__.__name__, + (self.__class__, inner_statement.__class__), + {}) + self.__dict__ = inner_statement.__dict__ + self.target_host = target_host diff --git a/cassandra/row_parser.pyx b/cassandra/row_parser.pyx index ec2b83bed7..d172f1bcaf 100644 --- a/cassandra/row_parser.pyx +++ b/cassandra/row_parser.pyx @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -13,26 +15,38 @@ # limitations under the License. from cassandra.parsing cimport ParseDesc, ColumnParser +from cassandra.policies import ColDesc +from cassandra.obj_parser import TupleRowParser from cassandra.deserializers import make_deserializers include "ioutils.pyx" def make_recv_results_rows(ColumnParser colparser): - def recv_results_rows(cls, f, int protocol_version, user_type_map): + def recv_results_rows(self, f, int protocol_version, user_type_map, result_metadata, column_encryption_policy): """ Parse protocol data given as a BytesIO f into a set of columns (e.g. list of tuples) This is used as the recv_results_rows method of (Fast)ResultMessage """ - paging_state, column_metadata = cls.recv_results_metadata(f, user_type_map) + self.recv_results_metadata(f, user_type_map) - colnames = [c[2] for c in column_metadata] - coltypes = [c[3] for c in column_metadata] + column_metadata = self.column_metadata or result_metadata - desc = ParseDesc(colnames, coltypes, make_deserializers(coltypes), - protocol_version) - reader = BytesIOReader(f.read()) - parsed_rows = colparser.parse_rows(reader, desc) + self.column_names = [md[2] for md in column_metadata] + self.column_types = [md[3] for md in column_metadata] - return (paging_state, (colnames, parsed_rows)) + desc = ParseDesc(self.column_names, self.column_types, column_encryption_policy, + [ColDesc(md[0], md[1], md[2]) for md in column_metadata], + make_deserializers(self.column_types), protocol_version) + reader = BytesIOReader(f.read()) + try: + self.parsed_rows = colparser.parse_rows(reader, desc) + except Exception as e: + # Use explicitly the TupleRowParser to display better error messages for column decoding failures + rowparser = TupleRowParser() + reader.buf_ptr = reader.buf + reader.pos = 0 + rowcount = read_int(reader) + for i in range(rowcount): + rowparser.unpack_row(reader, desc) return recv_results_rows diff --git a/cassandra/segment.py b/cassandra/segment.py new file mode 100644 index 0000000000..2d7a107566 --- /dev/null +++ b/cassandra/segment.py @@ -0,0 +1,222 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import zlib + +from cassandra import DriverException +from cassandra.marshal import int32_pack +from cassandra.protocol import write_uint_le, read_uint_le + +CRC24_INIT = 0x875060 +CRC24_POLY = 0x1974F0B +CRC24_LENGTH = 3 +CRC32_LENGTH = 4 +CRC32_INITIAL = zlib.crc32(b"\xfa\x2d\x55\xca") + + +class CrcException(Exception): + """ + CRC mismatch error. + + TODO: here to avoid import cycles with cassandra.connection. In the next + major, the exceptions should be declared in a separated exceptions.py + file. + """ + pass + + +def compute_crc24(data, length): + crc = CRC24_INIT + + for _ in range(length): + crc ^= (data & 0xff) << 16 + data >>= 8 + + for i in range(8): + crc <<= 1 + if crc & 0x1000000 != 0: + crc ^= CRC24_POLY + + return crc + + +def compute_crc32(data, value): + crc32 = zlib.crc32(data, value) + return crc32 + + +class SegmentHeader(object): + + payload_length = None + uncompressed_payload_length = None + is_self_contained = None + + def __init__(self, payload_length, uncompressed_payload_length, is_self_contained): + self.payload_length = payload_length + self.uncompressed_payload_length = uncompressed_payload_length + self.is_self_contained = is_self_contained + + @property + def segment_length(self): + """ + Return the total length of the segment, including the CRC. + """ + hl = SegmentCodec.UNCOMPRESSED_HEADER_LENGTH if self.uncompressed_payload_length < 1 \ + else SegmentCodec.COMPRESSED_HEADER_LENGTH + return hl + CRC24_LENGTH + self.payload_length + CRC32_LENGTH + + +class Segment(object): + + MAX_PAYLOAD_LENGTH = 128 * 1024 - 1 + + payload = None + is_self_contained = None + + def __init__(self, payload, is_self_contained): + self.payload = payload + self.is_self_contained = is_self_contained + + +class SegmentCodec(object): + + COMPRESSED_HEADER_LENGTH = 5 + UNCOMPRESSED_HEADER_LENGTH = 3 + FLAG_OFFSET = 17 + + compressor = None + decompressor = None + + def __init__(self, compressor=None, decompressor=None): + self.compressor = compressor + self.decompressor = decompressor + + @property + def header_length(self): + return self.COMPRESSED_HEADER_LENGTH if self.compression \ + else self.UNCOMPRESSED_HEADER_LENGTH + + @property + def header_length_with_crc(self): + return (self.COMPRESSED_HEADER_LENGTH if self.compression + else self.UNCOMPRESSED_HEADER_LENGTH) + CRC24_LENGTH + + @property + def compression(self): + return self.compressor and self.decompressor + + def compress(self, data): + # the uncompressed length is already encoded in the header, so + # we remove it here + return self.compressor(data)[4:] + + def decompress(self, encoded_data, uncompressed_length): + return self.decompressor(int32_pack(uncompressed_length) + encoded_data) + + def encode_header(self, buffer, payload_length, uncompressed_length, is_self_contained): + if payload_length > Segment.MAX_PAYLOAD_LENGTH: + raise DriverException('Payload length exceed Segment.MAX_PAYLOAD_LENGTH') + + header_data = payload_length + + flag_offset = self.FLAG_OFFSET + if self.compression: + header_data |= uncompressed_length << flag_offset + flag_offset += 17 + + if is_self_contained: + header_data |= 1 << flag_offset + + write_uint_le(buffer, header_data, size=self.header_length) + header_crc = compute_crc24(header_data, self.header_length) + write_uint_le(buffer, header_crc, size=CRC24_LENGTH) + + def _encode_segment(self, buffer, payload, is_self_contained): + """ + Encode a message to a single segment. 
+ """ + uncompressed_payload = payload + uncompressed_payload_length = len(payload) + + if self.compression: + compressed_payload = self.compress(uncompressed_payload) + if len(compressed_payload) >= uncompressed_payload_length: + encoded_payload = uncompressed_payload + uncompressed_payload_length = 0 + else: + encoded_payload = compressed_payload + else: + encoded_payload = uncompressed_payload + + payload_length = len(encoded_payload) + self.encode_header(buffer, payload_length, uncompressed_payload_length, is_self_contained) + payload_crc = compute_crc32(encoded_payload, CRC32_INITIAL) + buffer.write(encoded_payload) + write_uint_le(buffer, payload_crc) + + def encode(self, buffer, msg): + """ + Encode a message to one of more segments. + """ + msg_length = len(msg) + + if msg_length > Segment.MAX_PAYLOAD_LENGTH: + payloads = [] + for i in range(0, msg_length, Segment.MAX_PAYLOAD_LENGTH): + payloads.append(msg[i:i + Segment.MAX_PAYLOAD_LENGTH]) + else: + payloads = [msg] + + is_self_contained = len(payloads) == 1 + for payload in payloads: + self._encode_segment(buffer, payload, is_self_contained) + + def decode_header(self, buffer): + header_data = read_uint_le(buffer, self.header_length) + + expected_header_crc = read_uint_le(buffer, CRC24_LENGTH) + actual_header_crc = compute_crc24(header_data, self.header_length) + if actual_header_crc != expected_header_crc: + raise CrcException('CRC mismatch on header {:x}. Received {:x}", computed {:x}.'.format( + header_data, expected_header_crc, actual_header_crc)) + + payload_length = header_data & Segment.MAX_PAYLOAD_LENGTH + header_data >>= 17 + + if self.compression: + uncompressed_payload_length = header_data & Segment.MAX_PAYLOAD_LENGTH + header_data >>= 17 + else: + uncompressed_payload_length = -1 + + is_self_contained = (header_data & 1) == 1 + + return SegmentHeader(payload_length, uncompressed_payload_length, is_self_contained) + + def decode(self, buffer, header): + encoded_payload = buffer.read(header.payload_length) + expected_payload_crc = read_uint_le(buffer) + + actual_payload_crc = compute_crc32(encoded_payload, CRC32_INITIAL) + if actual_payload_crc != expected_payload_crc: + raise CrcException('CRC mismatch on payload. Received {:x}", computed {:x}.'.format( + expected_payload_crc, actual_payload_crc)) + + payload = encoded_payload + if self.compression and header.uncompressed_payload_length > 0: + payload = self.decompress(encoded_payload, header.uncompressed_payload_length) + + return Segment(payload, header.is_self_contained) diff --git a/cassandra/timestamps.py b/cassandra/timestamps.py new file mode 100644 index 0000000000..e2a2c1ea4c --- /dev/null +++ b/cassandra/timestamps.py @@ -0,0 +1,111 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +This module contains utilities for generating timestamps for client-side +timestamp specification. +""" + +import logging +import time +from threading import Lock + +log = logging.getLogger(__name__) + +class MonotonicTimestampGenerator(object): + """ + An object that, when called, returns ``int(time.time() * 1e6)`` when + possible, but, if the value returned by ``time.time`` doesn't increase, + drifts into the future and logs warnings. + Exposed configuration attributes can be configured with arguments to + ``__init__`` or by changing attributes on an initialized object. + + .. versionadded:: 3.8.0 + """ + + warn_on_drift = True + """ + If true, log warnings when timestamps drift into the future as allowed by + :attr:`warning_threshold` and :attr:`warning_interval`. + """ + + warning_threshold = 1 + """ + This object will only issue warnings when the returned timestamp drifts + more than ``warning_threshold`` seconds into the future. + Defaults to 1 second. + """ + + warning_interval = 1 + """ + This object will only issue warnings every ``warning_interval`` seconds. + Defaults to 1 second. + """ + + def __init__(self, warn_on_drift=True, warning_threshold=1, warning_interval=1): + self.lock = Lock() + with self.lock: + self.last = 0 + self._last_warn = 0 + self.warn_on_drift = warn_on_drift + self.warning_threshold = warning_threshold + self.warning_interval = warning_interval + + def _next_timestamp(self, now, last): + """ + Returns the timestamp that should be used if ``now`` is the current + time and ``last`` is the last timestamp returned by this object. + Intended for internal and testing use only; to generate timestamps, + call an instantiated ``MonotonicTimestampGenerator`` object. + + :param int now: an integer to be used as the current time, typically + representing the current time in microseconds since the UNIX epoch + :param int last: an integer representing the last timestamp returned by + this object + """ + if now > last: + self.last = now + return now + else: + self._maybe_warn(now=now) + self.last = last + 1 + return self.last + + def __call__(self): + """ + Makes ``MonotonicTimestampGenerator`` objects callable; defers + internally to _next_timestamp. + """ + with self.lock: + return self._next_timestamp(now=int(time.time() * 1e6), + last=self.last) + + def _maybe_warn(self, now): + # should be called from inside the self.lock. + diff = self.last - now + since_last_warn = now - self._last_warn + + warn = (self.warn_on_drift and + (diff >= self.warning_threshold * 1e6) and + (since_last_warn >= self.warning_interval * 1e6)) + if warn: + log.warning( + "Clock skew detected: current tick ({now}) was {diff} " + "microseconds behind the last generated timestamp " + "({last}), returned timestamps will be artificially " + "incremented to guarantee monotonicity.".format( + now=now, diff=diff, last=self.last)) + self._last_warn = now diff --git a/cassandra/tuple.pxd b/cassandra/tuple.pxd index 840cb7eb0b..b519e177bb 100644 --- a/cassandra/tuple.pxd +++ b/cassandra/tuple.pxd @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/cassandra/type_codes.pxd b/cassandra/type_codes.pxd index b7e491f095..336263b83c 100644 --- a/cassandra/type_codes.pxd +++ b/cassandra/type_codes.pxd @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/cassandra/type_codes.py b/cassandra/type_codes.py index 2f0ce8f5a0..eab9a3344a 100644 --- a/cassandra/type_codes.py +++ b/cassandra/type_codes.py @@ -25,6 +25,11 @@ 0x000E Varint 0x000F Timeuuid 0x0010 Inet + 0x0011 SimpleDateType + 0x0012 TimeType + 0x0013 ShortType + 0x0014 ByteType + 0x0015 DurationType 0x0020 List: the value is an [option], representing the type of the elements of the list. 0x0021 Map: the value is two [option], representing the types of the @@ -54,9 +59,9 @@ TimeType = 0x0012 ShortType = 0x0013 ByteType = 0x0014 +DurationType = 0x0015 ListType = 0x0020 MapType = 0x0021 SetType = 0x0022 UserType = 0x0030 TupleType = 0x0031 - diff --git a/cassandra/util.py b/cassandra/util.py index ab6968e035..f973912574 100644 --- a/cassandra/util.py +++ b/cassandra/util.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,19 +14,43 @@ # See the License for the specific language governing permissions and # limitations under the License. 
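The new ``cassandra/timestamps.py`` module above provides the generator used for client-side timestamps; it can also be set as ``Cluster.timestamp_generator`` to apply it cluster-wide. A short usage sketch::

    from cassandra.timestamps import MonotonicTimestampGenerator

    gen = MonotonicTimestampGenerator(warning_threshold=2, warning_interval=10)

    # Each call returns microseconds since the epoch; if the system clock
    # stalls or moves backwards, the generator increments instead of repeating.
    first = gen()
    second = gen()
    assert second > first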
-from __future__ import with_statement +from _weakref import ref import calendar +from collections import OrderedDict +from collections.abc import Mapping import datetime +from functools import total_ordering +from itertools import chain +import keyword +import logging +import pickle import random -import six -import uuid +import re +import socket import sys +import time +import uuid + +_HAS_GEOMET = True +try: + from geomet import wkt +except: + _HAS_GEOMET = False + + +from cassandra import DriverException + +DATETIME_EPOC = datetime.datetime(1970, 1, 1).replace(tzinfo=None) +UTC_DATETIME_EPOC = datetime.datetime.fromtimestamp(0, tz=datetime.timezone.utc).replace(tzinfo=None) -DATETIME_EPOC = datetime.datetime(1970, 1, 1) +_nan = float('nan') + +log = logging.getLogger(__name__) assert sys.byteorder in ('little', 'big') is_little_endian = sys.byteorder == 'little' + def datetime_from_timestamp(timestamp): """ Creates a timezone-agnostic datetime from timestamp (in seconds) in a consistent manner. @@ -37,6 +63,28 @@ def datetime_from_timestamp(timestamp): return dt +def utc_datetime_from_ms_timestamp(timestamp): + """ + Creates a UTC datetime from a timestamp in milliseconds. See + :meth:`datetime_from_timestamp`. + + Raises an `OverflowError` if the timestamp is out of range for + :class:`~datetime.datetime`. + + :param timestamp: timestamp, in milliseconds + """ + return UTC_DATETIME_EPOC + datetime.timedelta(milliseconds=timestamp) + + +def ms_timestamp_from_datetime(dt): + """ + Converts a datetime to a timestamp expressed in milliseconds. + + :param dt: a :class:`datetime.datetime` + """ + return int(round((dt - UTC_DATETIME_EPOC).total_seconds() * 1000)) + + def unix_time_from_uuid1(uuid_arg): """ Converts a version 1 :class:`uuid.UUID` to a timestamp with the same precision @@ -135,145 +183,40 @@ def uuid_from_time(time_arg, node=None, clock_seq=None): """ The highest possible TimeUUID, as sorted by Cassandra. """ -try: - from collections import OrderedDict -except ImportError: - # OrderedDict from Python 2.7+ - - # Copyright (c) 2009 Raymond Hettinger - # - # Permission is hereby granted, free of charge, to any person - # obtaining a copy of this software and associated documentation files - # (the "Software"), to deal in the Software without restriction, - # including without limitation the rights to use, copy, modify, merge, - # publish, distribute, sublicense, and/or sell copies of the Software, - # and to permit persons to whom the Software is furnished to do so, - # subject to the following conditions: - # - # The above copyright notice and this permission notice shall be - # included in all copies or substantial portions of the Software. - # - # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES - # OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND - # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT - # HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, - # WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING - # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR - # OTHER DEALINGS IN THE SOFTWARE. - from UserDict import DictMixin - - class OrderedDict(dict, DictMixin): # noqa - """ A dictionary which maintains the insertion order of keys. """ - - def __init__(self, *args, **kwds): - """ A dictionary which maintains the insertion order of keys. 
""" - - if len(args) > 1: - raise TypeError('expected at most 1 arguments, got %d' % len(args)) - try: - self.__end - except AttributeError: - self.clear() - self.update(*args, **kwds) - - def clear(self): - self.__end = end = [] - end += [None, end, end] # sentinel node for doubly linked list - self.__map = {} # key --> [key, prev, next] - dict.clear(self) - - def __setitem__(self, key, value): - if key not in self: - end = self.__end - curr = end[1] - curr[2] = end[1] = self.__map[key] = [key, curr, end] - dict.__setitem__(self, key, value) - - def __delitem__(self, key): - dict.__delitem__(self, key) - key, prev, next = self.__map.pop(key) - prev[2] = next - next[1] = prev - - def __iter__(self): - end = self.__end - curr = end[2] - while curr is not end: - yield curr[0] - curr = curr[2] - - def __reversed__(self): - end = self.__end - curr = end[1] - while curr is not end: - yield curr[0] - curr = curr[1] - - def popitem(self, last=True): - if not self: - raise KeyError('dictionary is empty') - if last: - key = next(reversed(self)) - else: - key = next(iter(self)) - value = self.pop(key) - return key, value - - def __reduce__(self): - items = [[k, self[k]] for k in self] - tmp = self.__map, self.__end - del self.__map, self.__end - inst_dict = vars(self).copy() - self.__map, self.__end = tmp - if inst_dict: - return (self.__class__, (items,), inst_dict) - return self.__class__, (items,) - - def keys(self): - return list(self) - - setdefault = DictMixin.setdefault - update = DictMixin.update - pop = DictMixin.pop - values = DictMixin.values - items = DictMixin.items - iterkeys = DictMixin.iterkeys - itervalues = DictMixin.itervalues - iteritems = DictMixin.iteritems - - def __repr__(self): - if not self: - return '%s()' % (self.__class__.__name__,) - return '%s(%r)' % (self.__class__.__name__, self.items()) - - def copy(self): - return self.__class__(self) - - @classmethod - def fromkeys(cls, iterable, value=None): - d = cls() - for key in iterable: - d[key] = value - return d - - def __eq__(self, other): - if isinstance(other, OrderedDict): - if len(self) != len(other): - return False - for p, q in zip(self.items(), other.items()): - if p != q: - return False - return True - return dict.__eq__(self, other) - - def __ne__(self, other): - return not self == other - - -# WeakSet from Python 2.7+ (https://code.google.com/p/weakrefset) +def _addrinfo_or_none(contact_point, port): + """ + A helper function that wraps socket.getaddrinfo and returns None + when it fails to, e.g. resolve one of the hostnames. Used to address + PYTHON-895. + """ + try: + value = socket.getaddrinfo(contact_point, port, + socket.AF_UNSPEC, socket.SOCK_STREAM) + return value + except socket.gaierror: + log.debug('Could not resolve hostname "{}" ' + 'with port {}'.format(contact_point, port)) + return None -from _weakref import ref + +def _addrinfo_to_ip_strings(addrinfo): + """ + Helper function that consumes the data output by socket.getaddrinfo and + extracts the IP address from the sockaddr portion of the result. + + Since this is meant to be used in conjunction with _addrinfo_or_none, + this will pass None and EndPoint instances through unaffected. 
+ """ + if addrinfo is None: + return None + return [(entry[4][0], entry[4][1]) for entry in addrinfo] + + +def _resolve_contact_points_to_string_map(contact_points): + return OrderedDict( + ('{cp}:{port}'.format(cp=cp, port=port), _addrinfo_to_ip_strings(_addrinfo_or_none(cp, port))) + for cp, port in contact_points + ) class _IterationGuard(object): @@ -486,9 +429,6 @@ def isdisjoint(self, other): return len(self.intersection(other)) == 0 -from bisect import bisect_left - - class SortedSet(object): ''' A sorted set based on sorted list @@ -592,7 +532,7 @@ def __ixor__(self, other): return self def __contains__(self, item): - i = bisect_left(self._items, item) + i = self._find_insertion(item) return i < len(self._items) and self._items[i] == item def __delitem__(self, i): @@ -602,7 +542,7 @@ def __delslice__(self, i, j): del self._items[i:j] def add(self, item): - i = bisect_left(self._items, item) + i = self._find_insertion(item) if i < len(self._items): if self._items[i] != item: self._items.insert(i, item) @@ -636,7 +576,7 @@ def pop(self): return self._items.pop() def remove(self, item): - i = bisect_left(self._items, item) + i = self._find_insertion(item) if i < len(self._items): if self._items[i] == item: self._items.pop(i) @@ -647,18 +587,8 @@ def union(self, *others): union = sortedset() union._items = list(self._items) for other in others: - if isinstance(other, self.__class__): - i = 0 - for item in other._items: - i = bisect_left(union._items, item, i) - if i < len(union._items): - if item != union._items[i]: - union._items.insert(i, item) - else: - union._items.append(item) - else: - for item in other: - union.add(item) + for item in other: + union.add(item) return union def intersection(self, *others): @@ -684,50 +614,52 @@ def symmetric_difference(self, other): def _diff(self, other): diff = sortedset() - if isinstance(other, self.__class__): - i = 0 - for item in self._items: - i = bisect_left(other._items, item, i) - if i < len(other._items): - if item != other._items[i]: - diff._items.append(item) - else: - diff._items.append(item) - else: - for item in self._items: - if item not in other: - diff.add(item) + for item in self._items: + if item not in other: + diff.add(item) return diff def _intersect(self, other): isect = sortedset() - if isinstance(other, self.__class__): - i = 0 - for item in self._items: - i = bisect_left(other._items, item, i) - if i < len(other._items): - if item == other._items[i]: - isect._items.append(item) - else: - break - else: - for item in self._items: - if item in other: - isect.add(item) + for item in self._items: + if item in other: + isect.add(item) return isect -sortedset = SortedSet # backwards-compatibility - + def _find_insertion(self, x): + # this uses bisect_left algorithm unless it has elements it can't compare, + # in which case it defaults to grouping non-comparable items at the beginning or end, + # and scanning sequentially to find an insertion point + a = self._items + lo = 0 + hi = len(a) + try: + while lo < hi: + mid = (lo + hi) // 2 + if a[mid] < x: lo = mid + 1 + else: hi = mid + except TypeError: + # could not compare a[mid] with x + # start scanning to find insertion point while swallowing type errors + lo = 0 + compared_one = False # flag is used to determine whether un-comparables are grouped at the front or back + while lo < hi: + try: + if a[lo] == x or a[lo] >= x: break + compared_one = True + except TypeError: + if compared_one: break + lo += 1 + return lo -from collections import Mapping -from six.moves import 
cPickle +sortedset = SortedSet # backwards-compatibility class OrderedMap(Mapping): ''' An ordered map that accepts non-hashable types for keys. It also maintains the insertion order of items, behaving as OrderedDict in that regard. These maps - are constructed and read just as normal mapping types, exept that they may + are constructed and read just as normal mapping types, except that they may contain arbitrary collections and other non-hashable items as keys:: >>> od = OrderedMap([({'one': 1, 'two': 2}, 'value'), @@ -738,7 +670,7 @@ class OrderedMap(Mapping): ['value', 'value2'] These constructs are needed to support nested collections in Cassandra 2.1.3+, - where frozen collections can be specified as parameters to others\*:: + where frozen collections can be specified as parameters to others:: CREATE TABLE example ( ... @@ -748,11 +680,6 @@ class OrderedMap(Mapping): This class derives from the (immutable) Mapping API. Objects in these maps are not intended be modified. - - \* Note: Because of the way Cassandra encodes nested types, when using the - driver with nested collections, :attr:`~.Cluster.protocol_version` must be 3 - or higher. - ''' def __init__(self, *args, **kwargs): @@ -770,7 +697,7 @@ def __init__(self, *args, **kwargs): for k, v in e: self._insert(k, v) - for k, v in six.iteritems(kwargs): + for k, v in kwargs.items(): self._insert(k, v) def _insert(self, key, value): @@ -836,7 +763,7 @@ def popitem(self): raise KeyError() def _serialize_key(self, key): - return cPickle.dumps(key) + return pickle.dumps(key) class OrderedMapSerializedKey(OrderedMap): @@ -854,13 +781,7 @@ def _serialize_key(self, key): return self.cass_key_type.serialize(key, self.protocol_version) -import datetime -import time - -if six.PY3: - long = int - - +@total_ordering class Time(object): ''' Idealized time, independent of day. @@ -881,15 +802,15 @@ def __init__(self, value): """ Initializer value can be: - - integer_type: absolute nanoseconds in the day - - datetime.time: built-in time - - string_type: a string time of the form "HH:MM:SS[.mmmuuunnn]" + - integer_type: absolute nanoseconds in the day + - datetime.time: built-in time + - string_type: a string time of the form "HH:MM:SS[.mmmuuunnn]" """ - if isinstance(value, six.integer_types): + if isinstance(value, int): self._from_timestamp(value) elif isinstance(value, datetime.time): self._from_time(value) - elif isinstance(value, six.string_types): + elif isinstance(value, str): self._from_timestring(value) else: raise TypeError('Time arguments must be a whole number, datetime.time, or string') @@ -924,6 +845,13 @@ def nanosecond(self): """ return self.nanosecond_time % Time.SECOND + def time(self): + """ + Return a built-in datetime.time (nanosecond precision truncated to micros). 
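``SortedSet`` now routes membership tests and insertions through ``_find_insertion``, which falls back to a linear scan when items cannot be ordered against each other and groups the non-comparable items at one end. A small sketch of the visible behaviour, assuming the caller really does want to mix such items::

    from cassandra.util import sortedset

    s = sortedset([3, 1, 2])
    assert list(s) == [1, 2, 3]

    # None cannot be ordered against ints on Python 3; the fallback scan still
    # finds it a stable position instead of raising TypeError.
    s.add(None)
    assert None in s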
+ """ + return datetime.time(hour=self.hour, minute=self.minute, second=self.second, + microsecond=self.nanosecond // Time.MICRO) + def _from_timestamp(self, t): if t >= Time.DAY: raise ValueError("value must be less than number of nanoseconds in a day (%d)" % Time.DAY) @@ -958,14 +886,19 @@ def __eq__(self, other): if isinstance(other, Time): return self.nanosecond_time == other.nanosecond_time - if isinstance(other, six.integer_types): + if isinstance(other, int): return self.nanosecond_time == other return self.nanosecond_time % Time.MICRO == 0 and \ datetime.time(hour=self.hour, minute=self.minute, second=self.second, microsecond=self.nanosecond // Time.MICRO) == other + def __ne__(self, other): + return not self.__eq__(other) + def __lt__(self, other): + if not isinstance(other, Time): + return NotImplemented return self.nanosecond_time < other.nanosecond_time def __repr__(self): @@ -976,6 +909,7 @@ def __str__(self): self.second, self.nanosecond) +@total_ordering class Date(object): ''' Idealized date: year, month, day @@ -997,15 +931,15 @@ def __init__(self, value): """ Initializer value can be: - - integer_type: absolute days from epoch (1970, 1, 1). Can be negative. - - datetime.date: built-in date - - string_type: a string time of the form "yyyy-mm-dd" + - integer_type: absolute days from epoch (1970, 1, 1). Can be negative. + - datetime.date: built-in date + - string_type: a string time of the form "yyyy-mm-dd" """ - if isinstance(value, six.integer_types): + if isinstance(value, int): self.days_from_epoch = value elif isinstance(value, (datetime.date, datetime.datetime)): self._from_timetuple(value.timetuple()) - elif isinstance(value, six.string_types): + elif isinstance(value, str): self._from_datestring(value) else: raise TypeError('Date arguments must be a whole number, datetime.date, or string') @@ -1045,7 +979,7 @@ def __eq__(self, other): if isinstance(other, Date): return self.days_from_epoch == other.days_from_epoch - if isinstance(other, six.integer_types): + if isinstance(other, int): return self.days_from_epoch == other try: @@ -1053,7 +987,12 @@ def __eq__(self, other): except Exception: return False + def __ne__(self, other): + return not self.__eq__(other) + def __lt__(self, other): + if not isinstance(other, Date): + return NotImplemented return self.days_from_epoch < other.days_from_epoch def __repr__(self): @@ -1067,107 +1006,799 @@ def __str__(self): # If we overflow datetime.[MIN|MAX] return str(self.days_from_epoch) -import socket -if hasattr(socket, 'inet_pton'): - inet_pton = socket.inet_pton - inet_ntop = socket.inet_ntop -else: + +inet_pton = socket.inet_pton +inet_ntop = socket.inet_ntop + + +# similar to collections.namedtuple, reproduced here because Python 2.6 did not have the rename logic +def _positional_rename_invalid_identifiers(field_names): + names_out = list(field_names) + for index, name in enumerate(field_names): + if (not all(c.isalnum() or c == '_' for c in name) + or keyword.iskeyword(name) + or not name + or name[0].isdigit() + or name.startswith('_')): + names_out[index] = 'field_%d_' % index + return names_out + + +def _sanitize_identifiers(field_names): + names_out = _positional_rename_invalid_identifiers(field_names) + if len(names_out) != len(set(names_out)): + observed_names = set() + for index, name in enumerate(names_out): + while names_out[index] in observed_names: + names_out[index] = "%s_" % (names_out[index],) + observed_names.add(names_out[index]) + return names_out + + +def list_contents_to_tuple(to_convert): + if 
isinstance(to_convert, list): + for n, i in enumerate(to_convert): + if isinstance(to_convert[n], list): + to_convert[n] = tuple(to_convert[n]) + return tuple(to_convert) + else: + return to_convert + + +class Point(object): """ - Windows doesn't have socket.inet_pton and socket.inet_ntop until Python 3.4 - This is an alternative impl using ctypes, based on this win_inet_pton project: - https://github.com/hickeroar/win_inet_pton + Represents a point geometry for DSE """ - import ctypes - class sockaddr(ctypes.Structure): + x = None + """ + x coordinate of the point + """ + + y = None + """ + y coordinate of the point + """ + + def __init__(self, x=_nan, y=_nan): + self.x = x + self.y = y + + def __eq__(self, other): + return isinstance(other, Point) and self.x == other.x and self.y == other.y + + def __hash__(self): + return hash((self.x, self.y)) + + def __str__(self): + """ + Well-known text representation of the point + """ + return "POINT (%r %r)" % (self.x, self.y) + + def __repr__(self): + return "%s(%r, %r)" % (self.__class__.__name__, self.x, self.y) + + @staticmethod + def from_wkt(s): + """ + Parse a Point geometry from a wkt string and return a new Point object. + """ + if not _HAS_GEOMET: + raise DriverException("Geomet is required to deserialize a wkt geometry.") + + try: + geom = wkt.loads(s) + except ValueError: + raise ValueError("Invalid WKT geometry: '{0}'".format(s)) + + if geom['type'] != 'Point': + raise ValueError("Invalid WKT geometry type. Expected 'Point', got '{0}': '{1}'".format(geom['type'], s)) + + coords = geom['coordinates'] + if len(coords) < 2: + x = y = _nan + else: + x = coords[0] + y = coords[1] + + return Point(x=x, y=y) + + +class LineString(object): + """ + Represents a linestring geometry for DSE + """ + + coords = None + """ + Tuple of (x, y) coordinates in the linestring + """ + def __init__(self, coords=tuple()): + """ + 'coords`: a sequence of (x, y) coordinates of points in the linestring """ - Shared struct for ipv4 and ipv6. + self.coords = tuple(coords) - https://msdn.microsoft.com/en-us/library/windows/desktop/ms740496(v=vs.85).aspx + def __eq__(self, other): + return isinstance(other, LineString) and self.coords == other.coords - ``__pad1`` always covers the port. + def __hash__(self): + return hash(self.coords) - When being used for ``sockaddr_in6``, ``ipv4_addr`` actually covers ``sin6_flowinfo``, resulting - in proper alignment for ``ipv6_addr``. + def __str__(self): """ - _fields_ = [("sa_family", ctypes.c_short), - ("__pad1", ctypes.c_ushort), - ("ipv4_addr", ctypes.c_byte * 4), - ("ipv6_addr", ctypes.c_byte * 16), - ("__pad2", ctypes.c_ulong)] - - if hasattr(ctypes, 'windll'): - WSAStringToAddressA = ctypes.windll.ws2_32.WSAStringToAddressA - WSAAddressToStringA = ctypes.windll.ws2_32.WSAAddressToStringA - else: - def not_windows(*args): - raise Exception("IPv6 addresses cannot be handled on Windows. 
" - "Missing ctypes.windll") - WSAStringToAddressA = not_windows - WSAAddressToStringA = not_windows - - def inet_pton(address_family, ip_string): - if address_family == socket.AF_INET: - return socket.inet_aton(ip_string) - - addr = sockaddr() - addr.sa_family = address_family - addr_size = ctypes.c_int(ctypes.sizeof(addr)) - - if WSAStringToAddressA( - ip_string, - address_family, - None, - ctypes.byref(addr), - ctypes.byref(addr_size) - ) != 0: - raise socket.error(ctypes.FormatError()) - - if address_family == socket.AF_INET6: - return ctypes.string_at(addr.ipv6_addr, 16) - - raise socket.error('unknown address family') - - def inet_ntop(address_family, packed_ip): - if address_family == socket.AF_INET: - return socket.inet_ntoa(packed_ip) - - addr = sockaddr() - addr.sa_family = address_family - addr_size = ctypes.c_int(ctypes.sizeof(addr)) - ip_string = ctypes.create_string_buffer(128) - ip_string_size = ctypes.c_int(ctypes.sizeof(ip_string)) - - if address_family == socket.AF_INET6: - if len(packed_ip) != ctypes.sizeof(addr.ipv6_addr): - raise socket.error('packed IP wrong length for inet_ntoa') - ctypes.memmove(addr.ipv6_addr, packed_ip, 16) + Well-known text representation of the LineString + """ + if not self.coords: + return "LINESTRING EMPTY" + return "LINESTRING (%s)" % ', '.join("%r %r" % (x, y) for x, y in self.coords) + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, self.coords) + + @staticmethod + def from_wkt(s): + """ + Parse a LineString geometry from a wkt string and return a new LineString object. + """ + if not _HAS_GEOMET: + raise DriverException("Geomet is required to deserialize a wkt geometry.") + + try: + geom = wkt.loads(s) + except ValueError: + raise ValueError("Invalid WKT geometry: '{0}'".format(s)) + + if geom['type'] != 'LineString': + raise ValueError("Invalid WKT geometry type. 
Expected 'LineString', got '{0}': '{1}'".format(geom['type'], s)) + + geom['coordinates'] = list_contents_to_tuple(geom['coordinates']) + + return LineString(coords=geom['coordinates']) + + +class _LinearRing(object): + # no validation, no implicit closing; just used for poly composition, to + # mimic that of shapely.geometry.Polygon + def __init__(self, coords=tuple()): + self.coords = list_contents_to_tuple(coords) + + def __eq__(self, other): + return isinstance(other, _LinearRing) and self.coords == other.coords + + def __hash__(self): + return hash(self.coords) + + def __str__(self): + if not self.coords: + return "LINEARRING EMPTY" + return "LINEARRING (%s)" % ', '.join("%r %r" % (x, y) for x, y in self.coords) + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, self.coords) + + +class Polygon(object): + """ + Represents a polygon geometry for DSE + """ + + exterior = None + """ + _LinearRing representing the exterior of the polygon + """ + + interiors = None + """ + Tuple of _LinearRings representing interior holes in the polygon + """ + + def __init__(self, exterior=tuple(), interiors=None): + """ + 'exterior`: a sequence of (x, y) coordinates of points in the linestring + `interiors`: None, or a sequence of sequences or (x, y) coordinates of points describing interior linear rings + """ + self.exterior = _LinearRing(exterior) + self.interiors = tuple(_LinearRing(e) for e in interiors) if interiors else tuple() + + def __eq__(self, other): + return isinstance(other, Polygon) and self.exterior == other.exterior and self.interiors == other.interiors + + def __hash__(self): + return hash((self.exterior, self.interiors)) + + def __str__(self): + """ + Well-known text representation of the polygon + """ + if not self.exterior.coords: + return "POLYGON EMPTY" + rings = [ring.coords for ring in chain((self.exterior,), self.interiors)] + rings = ["(%s)" % ', '.join("%r %r" % (x, y) for x, y in ring) for ring in rings] + return "POLYGON (%s)" % ', '.join(rings) + + def __repr__(self): + return "%s(%r, %r)" % (self.__class__.__name__, self.exterior.coords, [ring.coords for ring in self.interiors]) + + @staticmethod + def from_wkt(s): + """ + Parse a Polygon geometry from a wkt string and return a new Polygon object. + """ + if not _HAS_GEOMET: + raise DriverException("Geomet is required to deserialize a wkt geometry.") + + try: + geom = wkt.loads(s) + except ValueError: + raise ValueError("Invalid WKT geometry: '{0}'".format(s)) + + if geom['type'] != 'Polygon': + raise ValueError("Invalid WKT geometry type. 
Expected 'Polygon', got '{0}': '{1}'".format(geom['type'], s)) + + coords = geom['coordinates'] + exterior = coords[0] if len(coords) > 0 else tuple() + interiors = coords[1:] if len(coords) > 1 else None + + return Polygon(exterior=exterior, interiors=interiors) + + +_distance_wkt_pattern = re.compile("distance *\\( *\\( *([\\d\\.-]+) *([\\d+\\.-]+) *\\) *([\\d+\\.-]+) *\\) *$", re.IGNORECASE) + + +class Distance(object): + """ + Represents a Distance geometry for DSE + """ + + x = None + """ + x coordinate of the center point + """ + + y = None + """ + y coordinate of the center point + """ + + radius = None + """ + radius to represent the distance from the center point + """ + + def __init__(self, x=_nan, y=_nan, radius=_nan): + self.x = x + self.y = y + self.radius = radius + + def __eq__(self, other): + return isinstance(other, Distance) and self.x == other.x and self.y == other.y and self.radius == other.radius + + def __hash__(self): + return hash((self.x, self.y, self.radius)) + + def __str__(self): + """ + Well-known text representation of the point + """ + return "DISTANCE ((%r %r) %r)" % (self.x, self.y, self.radius) + + def __repr__(self): + return "%s(%r, %r, %r)" % (self.__class__.__name__, self.x, self.y, self.radius) + + @staticmethod + def from_wkt(s): + """ + Parse a Distance geometry from a wkt string and return a new Distance object. + """ + + distance_match = _distance_wkt_pattern.match(s) + + if distance_match is None: + raise ValueError("Invalid WKT geometry: '{0}'".format(s)) + + x, y, radius = distance_match.groups() + return Distance(x, y, radius) + + +class Duration(object): + """ + Cassandra Duration Type + """ + + months = 0 + "" + days = 0 + "" + nanoseconds = 0 + "" + + def __init__(self, months=0, days=0, nanoseconds=0): + self.months = months + self.days = days + self.nanoseconds = nanoseconds + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.months == other.months and self.days == other.days and self.nanoseconds == other.nanoseconds + + def __repr__(self): + return "Duration({0}, {1}, {2})".format(self.months, self.days, self.nanoseconds) + + def __str__(self): + has_negative_values = self.months < 0 or self.days < 0 or self.nanoseconds < 0 + return '%s%dmo%dd%dns' % ( + '-' if has_negative_values else '', + abs(self.months), + abs(self.days), + abs(self.nanoseconds) + ) + + +class DateRangePrecision(object): + """ + An "enum" representing the valid values for :attr:`DateRange.precision`. 
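The DSE geometry types and ``Duration`` above expose WKT-style string forms. A few illustrative round trips (``from_wkt`` requires the optional ``geomet`` dependency to be installed)::

    from cassandra.util import Point, LineString, Duration

    p = Point(1.0, 2.0)
    assert str(p) == 'POINT (1.0 2.0)'

    ls = LineString(((0.0, 0.0), (1.0, 1.0)))
    assert str(ls) == 'LINESTRING (0.0 0.0, 1.0 1.0)'

    # Parsing WKT is only available when geomet is importable.
    assert Point.from_wkt('POINT (1.0 2.0)') == p

    # Duration keeps months, days and nanoseconds as separate components.
    assert str(Duration(months=1, days=2, nanoseconds=3)) == '1mo2d3ns'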
+ """ + YEAR = 'YEAR' + """ + """ + + MONTH = 'MONTH' + """ + """ + + DAY = 'DAY' + """ + """ + + HOUR = 'HOUR' + """ + """ + + MINUTE = 'MINUTE' + """ + """ + + SECOND = 'SECOND' + """ + """ + + MILLISECOND = 'MILLISECOND' + """ + """ + + PRECISIONS = (YEAR, MONTH, DAY, HOUR, + MINUTE, SECOND, MILLISECOND) + """ + """ + + @classmethod + def _to_int(cls, precision): + return cls.PRECISIONS.index(precision.upper()) + + @classmethod + def _round_to_precision(cls, ms, precision, default_dt): + try: + dt = utc_datetime_from_ms_timestamp(ms) + except OverflowError: + return ms + precision_idx = cls._to_int(precision) + replace_kwargs = {} + if precision_idx <= cls._to_int(DateRangePrecision.YEAR): + replace_kwargs['month'] = default_dt.month + if precision_idx <= cls._to_int(DateRangePrecision.MONTH): + replace_kwargs['day'] = default_dt.day + if precision_idx <= cls._to_int(DateRangePrecision.DAY): + replace_kwargs['hour'] = default_dt.hour + if precision_idx <= cls._to_int(DateRangePrecision.HOUR): + replace_kwargs['minute'] = default_dt.minute + if precision_idx <= cls._to_int(DateRangePrecision.MINUTE): + replace_kwargs['second'] = default_dt.second + if precision_idx <= cls._to_int(DateRangePrecision.SECOND): + # truncate to nearest 1000, so we deal in ms, not us + replace_kwargs['microsecond'] = (default_dt.microsecond // 1000) * 1000 + if precision_idx == cls._to_int(DateRangePrecision.MILLISECOND): + replace_kwargs['microsecond'] = int(round(dt.microsecond, -3)) + return ms_timestamp_from_datetime(dt.replace(**replace_kwargs)) + + @classmethod + def round_up_to_precision(cls, ms, precision): + # PYTHON-912: this is the only case in which we can't take as upper bound + # datetime.datetime.max because the month from ms may be February, and we'd + # be setting 31 as the month day + if precision == cls.MONTH: + date_ms = utc_datetime_from_ms_timestamp(ms) + upper_date = datetime.datetime.max.replace(year=date_ms.year, month=date_ms.month, + day=calendar.monthrange(date_ms.year, date_ms.month)[1]) else: - raise socket.error('unknown address family') + upper_date = datetime.datetime.max + return cls._round_to_precision(ms, precision, upper_date) - if WSAAddressToStringA( - ctypes.byref(addr), - addr_size, - None, - ip_string, - ctypes.byref(ip_string_size) - ) != 0: - raise socket.error(ctypes.FormatError()) + @classmethod + def round_down_to_precision(cls, ms, precision): + return cls._round_to_precision(ms, precision, datetime.datetime.min) - return ip_string[:ip_string_size.value - 1] +@total_ordering +class DateRangeBound(object): + """DateRangeBound(value, precision) + Represents a single date value and its precision for :class:`DateRange`. -import keyword + .. attribute:: milliseconds + Integer representing milliseconds since the UNIX epoch. May be negative. -# similar to collections.namedtuple, reproduced here because Python 2.6 did not have the rename logic -def _positional_rename_invalid_identifiers(field_names): - names_out = list(field_names) - for index, name in enumerate(field_names): - if (not all(c.isalnum() or c == '_' for c in name) - or keyword.iskeyword(name) - or not name - or name[0].isdigit() - or name.startswith('_')): - names_out[index] = 'field_%d_' % index - return names_out + .. attribute:: precision + + String representing the precision of a bound. Must be a valid + :class:`DateRangePrecision` member. + + :class:`DateRangeBound` uses a millisecond offset from the UNIX epoch to + allow :class:`DateRange` to represent values `datetime.datetime` cannot. 
+ For such values, string representions will show this offset rather than the + CQL representation. + """ + milliseconds = None + precision = None + + def __init__(self, value, precision): + """ + :param value: a value representing ms since the epoch. Accepts an + integer or a datetime. + :param precision: a string representing precision + """ + if precision is not None: + try: + self.precision = precision.upper() + except AttributeError: + raise TypeError('precision must be a string; got %r' % precision) + + if value is None: + milliseconds = None + elif isinstance(value, int): + milliseconds = value + elif isinstance(value, datetime.datetime): + value = value.replace( + microsecond=int(round(value.microsecond, -3)) + ) + milliseconds = ms_timestamp_from_datetime(value) + else: + raise ValueError('%r is not a valid value for DateRangeBound' % value) + + self.milliseconds = milliseconds + self.validate() + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return (self.milliseconds == other.milliseconds and + self.precision == other.precision) + + def __lt__(self, other): + return ((str(self.milliseconds), str(self.precision)) < + (str(other.milliseconds), str(other.precision))) + + def datetime(self): + """ + Return :attr:`milliseconds` as a :class:`datetime.datetime` if possible. + Raises an `OverflowError` if the value is out of range. + """ + return utc_datetime_from_ms_timestamp(self.milliseconds) + + def validate(self): + attrs = self.milliseconds, self.precision + if attrs == (None, None): + return + if None in attrs: + raise TypeError( + ("%s.datetime and %s.precision must not be None unless both " + "are None; Got: %r") % (self.__class__.__name__, + self.__class__.__name__, + self) + ) + if self.precision not in DateRangePrecision.PRECISIONS: + raise ValueError( + "%s.precision: expected value in %r; got %r" % ( + self.__class__.__name__, + DateRangePrecision.PRECISIONS, + self.precision + ) + ) + + @classmethod + def from_value(cls, value): + """ + Construct a new :class:`DateRangeBound` from a given value. If + possible, use the `value['milliseconds']` and `value['precision']` keys + of the argument. Otherwise, use the argument as a `(milliseconds, + precision)` iterable. 
+ + :param value: a dictlike or iterable object + """ + if isinstance(value, cls): + return value + + # if possible, use as a mapping + try: + milliseconds, precision = value.get('milliseconds'), value.get('precision') + except AttributeError: + milliseconds = precision = None + if milliseconds is not None and precision is not None: + return DateRangeBound(value=milliseconds, precision=precision) + + # otherwise, use as an iterable + return DateRangeBound(*value) + + def round_up(self): + if self.milliseconds is None or self.precision is None: + return self + self.milliseconds = DateRangePrecision.round_up_to_precision( + self.milliseconds, self.precision + ) + return self + + def round_down(self): + if self.milliseconds is None or self.precision is None: + return self + self.milliseconds = DateRangePrecision.round_down_to_precision( + self.milliseconds, self.precision + ) + return self + + _formatter_map = { + DateRangePrecision.YEAR: '%Y', + DateRangePrecision.MONTH: '%Y-%m', + DateRangePrecision.DAY: '%Y-%m-%d', + DateRangePrecision.HOUR: '%Y-%m-%dT%HZ', + DateRangePrecision.MINUTE: '%Y-%m-%dT%H:%MZ', + DateRangePrecision.SECOND: '%Y-%m-%dT%H:%M:%SZ', + DateRangePrecision.MILLISECOND: '%Y-%m-%dT%H:%M:%S', + } + + def __str__(self): + if self == OPEN_BOUND: + return '*' + + try: + dt = self.datetime() + except OverflowError: + return '%sms' % (self.milliseconds,) + + formatted = dt.strftime(self._formatter_map[self.precision]) + + if self.precision == DateRangePrecision.MILLISECOND: + # we'd like to just format with '%Y-%m-%dT%H:%M:%S.%fZ', but %f + # gives us more precision than we want, so we strftime up to %S and + # do the rest ourselves + return '%s.%03dZ' % (formatted, dt.microsecond / 1000) + + return formatted + + def __repr__(self): + return '%s(milliseconds=%r, precision=%r)' % ( + self.__class__.__name__, self.milliseconds, self.precision + ) + + +OPEN_BOUND = DateRangeBound(value=None, precision=None) +""" +Represents `*`, an open value or bound for :class:`DateRange`. +""" + + +@total_ordering +class DateRange(object): + """DateRange(lower_bound=None, upper_bound=None, value=None) + DSE DateRange Type + + .. attribute:: lower_bound + + :class:`~DateRangeBound` representing the lower bound of a bounded range. + + .. attribute:: upper_bound + + :class:`~DateRangeBound` representing the upper bound of a bounded range. + + .. attribute:: value + + :class:`~DateRangeBound` representing the value of a single-value range. + + As noted in its documentation, :class:`DateRangeBound` uses a millisecond + offset from the UNIX epoch to allow :class:`DateRange` to represent values + `datetime.datetime` cannot. For such values, string representions will show + this offset rather than the CQL representation. + """ + lower_bound = None + upper_bound = None + value = None + + def __init__(self, lower_bound=None, upper_bound=None, value=None): + """ + :param lower_bound: a :class:`DateRangeBound` or object accepted by + :meth:`DateRangeBound.from_value` to be used as a + :attr:`lower_bound`. Mutually exclusive with `value`. If + `upper_bound` is specified and this is not, the :attr:`lower_bound` + will be open. + :param upper_bound: a :class:`DateRangeBound` or object accepted by + :meth:`DateRangeBound.from_value` to be used as a + :attr:`upper_bound`. Mutually exclusive with `value`. If + `lower_bound` is specified and this is not, the :attr:`upper_bound` + will be open. 
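Putting ``DateRangeBound`` and ``DateRange`` together: lower bounds are rounded down and upper bounds rounded up to their precision, and an omitted bound renders as ``*``. A short sketch::

    import datetime
    from cassandra.util import DateRange, DateRangeBound

    # A range covering all of May 2024, expressed at DAY precision.
    may = DateRange(
        lower_bound=DateRangeBound(datetime.datetime(2024, 5, 1), 'DAY'),
        upper_bound=DateRangeBound(datetime.datetime(2024, 5, 31), 'DAY'),
    )
    assert str(may) == '[2024-05-01 TO 2024-05-31]'

    # Omitting one bound leaves it open.
    since_may = DateRange(lower_bound=DateRangeBound(datetime.datetime(2024, 5, 1), 'DAY'))
    assert str(since_may) == '[2024-05-01 TO *]'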
+ :param value: a :class:`DateRangeBound` or object accepted by + :meth:`DateRangeBound.from_value` to be used as :attr:`value`. Mutually + exclusive with `lower_bound` and `lower_bound`. + """ + + # if necessary, transform non-None args to DateRangeBounds + lower_bound = (DateRangeBound.from_value(lower_bound).round_down() + if lower_bound else lower_bound) + upper_bound = (DateRangeBound.from_value(upper_bound).round_up() + if upper_bound else upper_bound) + value = (DateRangeBound.from_value(value).round_down() + if value else value) + + # if we're using a 2-ended range but one bound isn't specified, specify + # an open bound + if lower_bound is None and upper_bound is not None: + lower_bound = OPEN_BOUND + if upper_bound is None and lower_bound is not None: + upper_bound = OPEN_BOUND + + self.lower_bound, self.upper_bound, self.value = ( + lower_bound, upper_bound, value + ) + self.validate() + + def validate(self): + if self.value is None: + if self.lower_bound is None or self.upper_bound is None: + raise ValueError( + '%s instances where value attribute is None must set ' + 'lower_bound or upper_bound; got %r' % ( + self.__class__.__name__, + self + ) + ) + else: # self.value is not None + if self.lower_bound is not None or self.upper_bound is not None: + raise ValueError( + '%s instances where value attribute is not None must not ' + 'set lower_bound or upper_bound; got %r' % ( + self.__class__.__name__, + self + ) + ) + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return (self.lower_bound == other.lower_bound and + self.upper_bound == other.upper_bound and + self.value == other.value) + + def __lt__(self, other): + return ((str(self.lower_bound), str(self.upper_bound), str(self.value)) < + (str(other.lower_bound), str(other.upper_bound), str(other.value))) + + def __str__(self): + if self.value: + return str(self.value) + else: + return '[%s TO %s]' % (self.lower_bound, self.upper_bound) + + def __repr__(self): + return '%s(lower_bound=%r, upper_bound=%r, value=%r)' % ( + self.__class__.__name__, + self.lower_bound, self.upper_bound, self.value + ) + + +@total_ordering +class Version(object): + """ + Internal minimalist class to compare versions. + A valid version is: .... + + TODO: when python2 support is removed, use packaging.version. + """ + + _version = None + major = None + minor = 0 + patch = 0 + build = 0 + prerelease = 0 + + def __init__(self, version): + self._version = version + if '-' in version: + version_without_prerelease, self.prerelease = version.split('-', 1) + else: + version_without_prerelease = version + parts = list(reversed(version_without_prerelease.split('.'))) + if len(parts) > 4: + prerelease_string = "-{}".format(self.prerelease) if self.prerelease else "" + log.warning("Unrecognized version: {}. Only 4 components plus prerelease are supported. " + "Assuming version as {}{}".format(version, '.'.join(parts[:-5:-1]), prerelease_string)) + + try: + self.major = int(parts.pop()) + except ValueError as e: + raise ValueError( + "Couldn't parse version {}. 
Version should start with a number".format(version))\ + .with_traceback(e.__traceback__) + try: + self.minor = int(parts.pop()) if parts else 0 + self.patch = int(parts.pop()) if parts else 0 + + if parts: # we have a build version + build = parts.pop() + try: + self.build = int(build) + except ValueError: + self.build = build + except ValueError: + assumed_version = "{}.{}.{}.{}-{}".format(self.major, self.minor, self.patch, self.build, self.prerelease) + log.warning("Unrecognized version {}. Assuming version as {}".format(version, assumed_version)) + + def __hash__(self): + return self._version + + def __repr__(self): + version_string = "Version({0}, {1}, {2}".format(self.major, self.minor, self.patch) + if self.build: + version_string += ", {}".format(self.build) + if self.prerelease: + version_string += ", {}".format(self.prerelease) + version_string += ")" + + return version_string + + def __str__(self): + return self._version + + @staticmethod + def _compare_version_part(version, other_version, cmp): + if not (isinstance(version, int) and + isinstance(other_version, int)): + version = str(version) + other_version = str(other_version) + + return cmp(version, other_version) + + def __eq__(self, other): + if not isinstance(other, Version): + return NotImplemented + + return (self.major == other.major and + self.minor == other.minor and + self.patch == other.patch and + self._compare_version_part(self.build, other.build, lambda s, o: s == o) and + self._compare_version_part(self.prerelease, other.prerelease, lambda s, o: s == o) + ) + + def __gt__(self, other): + if not isinstance(other, Version): + return NotImplemented + + is_major_ge = self.major >= other.major + is_minor_ge = self.minor >= other.minor + is_patch_ge = self.patch >= other.patch + is_build_gt = self._compare_version_part(self.build, other.build, lambda s, o: s > o) + is_build_ge = self._compare_version_part(self.build, other.build, lambda s, o: s >= o) + + # By definition, a prerelease comes BEFORE the actual release, so if a version + # doesn't have a prerelease, it's automatically greater than anything that does + if self.prerelease and not other.prerelease: + is_prerelease_gt = False + elif other.prerelease and not self.prerelease: + is_prerelease_gt = True + else: + is_prerelease_gt = self._compare_version_part(self.prerelease, other.prerelease, lambda s, o: s > o) \ + + return (self.major > other.major or + (is_major_ge and self.minor > other.minor) or + (is_major_ge and is_minor_ge and self.patch > other.patch) or + (is_major_ge and is_minor_ge and is_patch_ge and is_build_gt) or + (is_major_ge and is_minor_ge and is_patch_ge and is_build_ge and is_prerelease_gt) + ) diff --git a/docs.yaml b/docs.yaml new file mode 100644 index 0000000000..63269a3001 --- /dev/null +++ b/docs.yaml @@ -0,0 +1,117 @@ +title: DataStax Python Driver +summary: DataStax Python Driver for Apache Cassandra® +output: docs/_build/ +swiftype_drivers: pythondrivers +sections: + - title: N/A + prefix: / + type: sphinx + directory: docs + virtualenv_init: | + set -x + CASS_DRIVER_NO_CYTHON=1 pip install -r test-datastax-requirements.txt + # for newer versions this is redundant, but in older versions we need to + # install, e.g., the cassandra driver, and those versions don't specify + # the cassandra driver version in requirements files + CASS_DRIVER_NO_CYTHON=1 python setup.py develop + pip install "jinja2==2.8.1;python_version<'3.6'" "sphinx>=1.3,<2" geomet + # build extensions like libev + CASS_DRIVER_NO_CYTHON=1 python setup.py build_ext 
--inplace --force +versions: + - name: '3.29' + ref: 434b1f52 + - name: '3.28' + ref: 4325afb6 + - name: '3.27' + ref: 910f0282 + - name: '3.26' + ref: f1e9126 + - name: '3.25' + ref: a83c36a5 + - name: '3.24' + ref: 21cac12b + - name: '3.23' + ref: a40a2af7 + - name: '3.22' + ref: 1ccd5b99 + - name: '3.21' + ref: 5589d96b + - name: '3.20' + ref: d30d166f + - name: '3.19' + ref: ac2471f9 + - name: '3.18' + ref: ec36b957 + - name: '3.17' + ref: 38e359e1 + - name: '3.16' + ref: '3.16.0' + - name: '3.15' + ref: '2ce0bd97' + - name: '3.14' + ref: '9af8bd19' + - name: '3.13' + ref: '3.13.0' + - name: '3.12' + ref: '43b9c995' + - name: '3.11' + ref: '3.11.0' + - name: '3.10' + ref: 64572368 + - name: 3.9 + ref: 3.9-doc + - name: 3.8 + ref: 3.8-doc + - name: 3.7 + ref: 3.7-doc + - name: 3.6 + ref: 3.6-doc + - name: 3.5 + ref: 3.5-doc +redirects: + - \A\/(.*)/\Z: /\1.html +rewrites: + - search: http://www.datastax.com/docs/1.2/cql_cli/cql/BATCH + replace: https://docs.datastax.com/en/dse/6.7/cql/cql/cql_reference/cql_commands/cqlBatch.html + - search: http://www.datastax.com/documentation/cql/3.1/ + replace: https://docs.datastax.com/en/archived/cql/3.1/ + - search: 'https://community.datastax.com' + replace: 'https://www.datastax.com/dev/community' + - search: 'https://docs.datastax.com/en/astra/aws/doc/index.html' + replace: 'https://docs.datastax.com/en/astra-serverless/docs/connect/drivers/connect-python.html' + - search: 'http://cassandra.apache.org/doc/cql3/CQL.html#timeuuidFun' + replace: 'https://cassandra.apache.org/doc/3.11/cassandra/cql/functions.html#timeuuid-functions' + - search: 'http://cassandra.apache.org/doc/cql3/CQL.html#tokenFun' + replace: 'https://cassandra.apache.org/doc/3.11/cassandra/cql/functions.html#token' + - search: 'http://cassandra.apache.org/doc/cql3/CQL.html#collections' + replace: 'https://cassandra.apache.org/doc/3.11/cassandra/cql/types.html#collections' + - search: 'http://cassandra.apache.org/doc/cql3/CQL.html#batchStmt' + replace: 'https://cassandra.apache.org/doc/3.11/cassandra/cql/dml.html#batch_statement' + - search: 'http://cassandra.apache.org/doc/cql3/CQL-3.0.html#timeuuidFun' + replace: 'https://cassandra.apache.org/doc/3.11/cassandra/cql/functions.html#timeuuid-functions' + - search: 'http://cassandra.apache.org/doc/cql3/CQL-3.0.html#tokenFun' + replace: 'https://cassandra.apache.org/doc/3.11/cassandra/cql/functions.html#token' + - search: 'http://cassandra.apache.org/doc/cql3/CQL-3.0.html#collections' + replace: 'https://cassandra.apache.org/doc/3.11/cassandra/cql/types.html#collections' + - search: 'http://cassandra.apache.org/doc/cql3/CQL-3.0.html#batchStmt' + replace: 'https://cassandra.apache.org/doc/3.11/cassandra/cql/dml.html#batch_statement' +checks: + external_links: + exclude: + - 'https://twitter.com/dsJavaDriver' + - 'https://twitter.com/datastaxeng' + - 'https://twitter.com/datastax' + - 'https://projectreactor.io' + - 'https://docs.datastax.com/en/drivers/java/4.[0-9]+/com/datastax/oss/driver/internal/' + - 'http://www.planetcassandra.org/blog/user-defined-functions-in-cassandra-3-0/' + - 'http://www.planetcassandra.org/making-the-change-from-thrift-to-cql/' + - 'https://academy.datastax.com/slack' + - 'https://community.datastax.com/index.html' + - 'https://micrometer.io/docs' + - 'http://datastax.github.io/java-driver/features/shaded_jar/' + - 'http://aka.ms/vcpython27' + internal_links: + exclude: + - 'netty_pipeline/' + - '../core/' + - '%5Bguava%20eviction%5D' diff --git a/docs/.nav b/docs/.nav new file mode 100644 index 
0000000000..79f3029073 --- /dev/null +++ b/docs/.nav @@ -0,0 +1,21 @@ +installation +getting_started +execution_profiles +lwt +object_mapper +performance +query_paging +security +upgrading +user_defined_types +dates_and_times +cloud +column_encryption +geo_types +graph +classic_graph +graph_fluent +CHANGELOG +faq +api + diff --git a/docs/CHANGELOG.rst b/docs/CHANGELOG.rst new file mode 100644 index 0000000000..592a2c0efa --- /dev/null +++ b/docs/CHANGELOG.rst @@ -0,0 +1,5 @@ +********* +CHANGELOG +********* + +.. include:: ../CHANGELOG.rst diff --git a/docs/api/cassandra.rst b/docs/api/cassandra.rst index 5628099a92..d46aae56cb 100644 --- a/docs/api/cassandra.rst +++ b/docs/api/cassandra.rst @@ -14,6 +14,9 @@ .. autoclass:: ConsistencyLevel :members: +.. autoclass:: ProtocolVersion + :members: + .. autoclass:: UserFunctionDescriptor :members: :inherited-members: @@ -22,6 +25,12 @@ :members: :inherited-members: +.. autoexception:: DriverException() + :members: + +.. autoexception:: RequestExecutionException() + :members: + .. autoexception:: Unavailable() :members: @@ -34,6 +43,9 @@ .. autoexception:: WriteTimeout() :members: +.. autoexception:: CoordinationFailure() + :members: + .. autoexception:: ReadFailure() :members: @@ -43,6 +55,12 @@ .. autoexception:: FunctionFailure() :members: +.. autoexception:: RequestValidationException() + :members: + +.. autoexception:: ConfigurationException() + :members: + .. autoexception:: AlreadyExists() :members: diff --git a/docs/api/cassandra/cluster.rst b/docs/api/cassandra/cluster.rst index 9c546d3be8..a9a9d378a4 100644 --- a/docs/api/cassandra/cluster.rst +++ b/docs/api/cassandra/cluster.rst @@ -5,14 +5,14 @@ .. autoclass:: Cluster ([contact_points=('127.0.0.1',)][, port=9042][, executor_threads=2], **attr_kwargs) - Any of the mutable Cluster attributes may be set as keyword arguments to the constructor. + .. autoattribute:: contact_points + + .. autoattribute:: port .. autoattribute:: cql_version .. autoattribute:: protocol_version - .. autoattribute:: port - .. autoattribute:: compression .. autoattribute:: auth_provider @@ -22,16 +22,17 @@ .. autoattribute:: reconnection_policy .. autoattribute:: default_retry_policy + :annotation: = .. autoattribute:: conviction_policy_factory - .. autoattribute:: connection_class + .. autoattribute:: address_translator .. autoattribute:: metrics_enabled .. autoattribute:: metrics - .. autoattribute:: metadata + .. autoattribute:: ssl_context .. autoattribute:: ssl_options @@ -39,16 +40,40 @@ .. autoattribute:: max_schema_agreement_wait + .. autoattribute:: metadata + + .. autoattribute:: connection_class + .. autoattribute:: control_connection_timeout .. autoattribute:: idle_heartbeat_interval + .. autoattribute:: idle_heartbeat_timeout + .. autoattribute:: schema_event_refresh_window .. autoattribute:: topology_event_refresh_window + .. autoattribute:: status_event_refresh_window + + .. autoattribute:: prepare_on_all_hosts + + .. autoattribute:: reprepare_on_up + .. autoattribute:: connect_timeout + .. autoattribute:: schema_metadata_enabled + :annotation: = True + + .. autoattribute:: token_metadata_enabled + :annotation: = True + + .. autoattribute:: timestamp_generator + + .. autoattribute:: endpoint_factory + + .. autoattribute:: cloud + .. automethod:: connect .. automethod:: shutdown @@ -59,6 +84,8 @@ .. automethod:: unregister_listener + .. automethod:: add_execution_profile + .. automethod:: set_max_requests_per_connection .. automethod:: get_max_requests_per_connection @@ -75,6 +102,8 @@ .. 
automethod:: set_max_connections_per_host + .. automethod:: get_control_connection_host + .. automethod:: refresh_schema_metadata .. automethod:: refresh_keyspace_metadata @@ -91,29 +120,62 @@ .. automethod:: set_meta_refresh_enabled +.. autoclass:: ExecutionProfile (load_balancing_policy=, retry_policy=None, consistency_level=ConsistencyLevel.LOCAL_ONE, serial_consistency_level=None, request_timeout=10.0, row_factory=, speculative_execution_policy=None) + :members: + :exclude-members: consistency_level + + .. autoattribute:: consistency_level + :annotation: = LOCAL_ONE + +.. autoclass:: GraphExecutionProfile (load_balancing_policy=_NOT_SET, retry_policy=None, consistency_level=ConsistencyLevel.LOCAL_ONE, serial_consistency_level=None, request_timeout=30.0, row_factory=None, graph_options=None, continuous_paging_options=_NOT_SET) + :members: + +.. autoclass:: GraphAnalyticsExecutionProfile (load_balancing_policy=None, retry_policy=None, consistency_level=ConsistencyLevel.LOCAL_ONE, serial_consistency_level=None, request_timeout=3600. * 24. * 7., row_factory=None, graph_options=None) + :members: + +.. autodata:: EXEC_PROFILE_DEFAULT + :annotation: + +.. autodata:: EXEC_PROFILE_GRAPH_DEFAULT + :annotation: + +.. autodata:: EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT + :annotation: + +.. autodata:: EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT + :annotation: .. autoclass:: Session () .. autoattribute:: default_timeout + :annotation: = 10.0 .. autoattribute:: default_consistency_level :annotation: = LOCAL_ONE .. autoattribute:: default_serial_consistency_level + :annotation: = None .. autoattribute:: row_factory + :annotation: = .. autoattribute:: default_fetch_size .. autoattribute:: use_client_timestamp + .. autoattribute:: timestamp_generator + .. autoattribute:: encoder .. autoattribute:: client_protocol_handler - .. automethod:: execute(statement[, parameters][, timeout][, trace][, custom_payload]) + .. automethod:: execute(statement[, parameters][, timeout][, trace][, custom_payload][, paging_state][, host][, execute_as]) + + .. automethod:: execute_async(statement[, parameters][, trace][, custom_payload][, paging_state][, host][, execute_as]) + + .. automethod:: execute_graph(statement[, parameters][, trace][, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT][, execute_as]) - .. automethod:: execute_async(statement[, parameters][, trace][, custom_payload]) + .. automethod:: execute_graph_async(statement[, parameters][, trace][, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT][, execute_as]) .. automethod:: prepare(statement) @@ -121,6 +183,14 @@ .. automethod:: set_keyspace(keyspace) + .. automethod:: get_execution_profile + + .. automethod:: execution_profile_clone_update + + .. automethod:: add_request_init_listener + + .. automethod:: remove_request_init_listener + .. autoclass:: ResponseFuture () .. autoattribute:: query @@ -145,7 +215,7 @@ .. automethod:: add_errback(fn, *args, **kwargs) - .. automethod:: add_callbacks(callback, errback, callback_args=(), callback_kwargs=None, errback_args=(), errback_args=None) + .. automethod:: add_callbacks(callback, errback, callback_args=(), callback_kwargs=None, errback_args=(), errback_kwargs=None) .. autoclass:: ResultSet () :members: diff --git a/docs/api/cassandra/connection.rst b/docs/api/cassandra/connection.rst index 3e9851b1a3..32cca590c0 100644 --- a/docs/api/cassandra/connection.rst +++ b/docs/api/cassandra/connection.rst @@ -7,3 +7,15 @@ .. autoexception:: ConnectionShutdown () .. autoexception:: ConnectionBusy () .. autoexception:: ProtocolError () + +.. 
autoclass:: EndPoint + :members: + +.. autoclass:: EndPointFactory + :members: + +.. autoclass:: SniEndPoint + +.. autoclass:: SniEndPointFactory + +.. autoclass:: UnixSocketEndPoint diff --git a/docs/api/cassandra/cqlengine/columns.rst b/docs/api/cassandra/cqlengine/columns.rst index 670633f73a..d44be8adb8 100644 --- a/docs/api/cassandra/cqlengine/columns.rst +++ b/docs/api/cassandra/cqlengine/columns.rst @@ -6,12 +6,12 @@ Columns ------- - Columns in your models map to columns in your CQL table. You define CQL columns by defining column attributes on your model classes. - For a model to be valid it needs at least one primary key column and one non-primary key column. +Columns in your models map to columns in your CQL table. You define CQL columns by defining column attributes on your model classes. +For a model to be valid it needs at least one primary key column and one non-primary key column. - Just as in CQL, the order you define your columns in is important, and is the same order they are defined in on a model's corresponding table. +Just as in CQL, the order you define your columns in is important, and is the same order they are defined in on a model's corresponding table. - Each column on your model definitions needs to be an instance of a Column class. +Each column on your model definitions needs to be an instance of a Column class. .. autoclass:: Column(**kwargs) @@ -21,6 +21,8 @@ Columns .. autoattribute:: index + .. autoattribute:: custom_index + .. autoattribute:: db_field .. autoattribute:: default diff --git a/docs/api/cassandra/cqlengine/connection.rst b/docs/api/cassandra/cqlengine/connection.rst index 184a6026cb..0f584fcca2 100644 --- a/docs/api/cassandra/cqlengine/connection.rst +++ b/docs/api/cassandra/cqlengine/connection.rst @@ -8,3 +8,9 @@ .. autofunction:: set_session .. autofunction:: setup + +.. autofunction:: register_connection + +.. autofunction:: unregister_connection + +.. autofunction:: set_default_connection diff --git a/docs/api/cassandra/cqlengine/models.rst b/docs/api/cassandra/cqlengine/models.rst index c4529b3988..ee689a2b48 100644 --- a/docs/api/cassandra/cqlengine/models.rst +++ b/docs/api/cassandra/cqlengine/models.rst @@ -32,8 +32,12 @@ Model .. autoattribute:: __keyspace__ - .. _ttl-change: - .. autoattribute:: __default_ttl__ + .. autoattribute:: __connection__ + + .. attribute:: __default_ttl__ + :annotation: = None + + Will be deprecated in release 4.0. You can set the default ttl by configuring the table ``__options__``. See :ref:`ttl-change` for more details. .. autoattribute:: __discriminator_value__ @@ -60,7 +64,6 @@ Model __options__ = {'compaction': {'class': 'LeveledCompactionStrategy', 'sstable_size_in_mb': '64', 'tombstone_threshold': '.2'}, - 'read_repair_chance': '0.5', 'comment': 'User data stored here'} user_id = columns.UUID(primary_key=True) @@ -79,6 +82,8 @@ Model 'tombstone_compaction_interval': '86400'}, 'gc_grace_seconds': '0'} + .. autoattribute:: __compute_routing_key__ + The base methods allow creating, storing, and querying modeled objects. @@ -98,7 +103,7 @@ Model TestIfNotExistsModel.if_not_exists().create(id=id, count=9, text='111111111111') except LWTException as e: # handle failure case - print e.existing # dict containing LWT result fields + print(e.existing # dict containing LWT result fields) This method is supported on Cassandra 2.0 or later. 
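The ``cluster.rst`` entries added above document the execution profile API. For orientation, a typical setup looks roughly like this (assumes a node reachable on localhost; the profile parameters are illustrative, not recommendations)::

    from cassandra import ConsistencyLevel
    from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT
    from cassandra.policies import RoundRobinPolicy

    profile = ExecutionProfile(
        load_balancing_policy=RoundRobinPolicy(),
        consistency_level=ConsistencyLevel.LOCAL_QUORUM,
        request_timeout=15.0,
    )
    cluster = Cluster(execution_profiles={EXEC_PROFILE_DEFAULT: profile})
    session = cluster.connect()
    row = session.execute("SELECT release_version FROM system.local").one()
    cluster.shutdown()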
@@ -139,7 +144,7 @@ Model t.iff(count=5).update('other text') except LWTException as e: # handle failure case - print e.existing # existing object + print(e.existing # existing object) .. automethod:: get @@ -165,6 +170,10 @@ Model Sets the ttl values to run instance updates and inserts queries with. + .. method:: using(connection=None) + + Change the context on the fly of the model instance (keyspace, connection) + .. automethod:: column_family_name Models also support dict-like access: diff --git a/docs/api/cassandra/cqlengine/query.rst b/docs/api/cassandra/cqlengine/query.rst index ad5489f207..ce8f764b6b 100644 --- a/docs/api/cassandra/cqlengine/query.rst +++ b/docs/api/cassandra/cqlengine/query.rst @@ -42,14 +42,28 @@ The methods here are used to filter, order, and constrain results. .. automethod:: allow_filtering + .. automethod:: only + + .. automethod:: defer + .. automethod:: timestamp .. automethod:: ttl + .. automethod:: using + .. _blind_updates: .. automethod:: update +.. autoclass:: BatchQuery + :members: + + .. automethod:: add_query + .. automethod:: execute + +.. autoclass:: ContextQuery + .. autoclass:: DoesNotExist .. autoclass:: MultipleObjectsReturned diff --git a/docs/api/cassandra/datastax/graph/fluent/index.rst b/docs/api/cassandra/datastax/graph/fluent/index.rst new file mode 100644 index 0000000000..5547e0fdd7 --- /dev/null +++ b/docs/api/cassandra/datastax/graph/fluent/index.rst @@ -0,0 +1,24 @@ +:mod:`cassandra.datastax.graph.fluent` +====================================== + +.. module:: cassandra.datastax.graph.fluent + +.. autoclass:: DseGraph + + .. autoattribute:: DSE_GRAPH_QUERY_LANGUAGE + + .. automethod:: create_execution_profile + + .. automethod:: query_from_traversal + + .. automethod:: traversal_source(session=None, graph_name=None, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT, traversal_class=None) + + .. automethod:: batch(session=None, execution_profile=None) + +.. autoclass:: DSESessionRemoteGraphConnection(session[, graph_name, execution_profile]) + +.. autoclass:: BaseGraphRowFactory + +.. autoclass:: graph_traversal_row_factory + +.. autoclass:: graph_traversal_dse_object_row_factory diff --git a/docs/api/cassandra/datastax/graph/fluent/predicates.rst b/docs/api/cassandra/datastax/graph/fluent/predicates.rst new file mode 100644 index 0000000000..f6e86f6451 --- /dev/null +++ b/docs/api/cassandra/datastax/graph/fluent/predicates.rst @@ -0,0 +1,14 @@ +:mod:`cassandra.datastax.graph.fluent.predicates` +================================================= + +.. module:: cassandra.datastax.graph.fluent.predicates + + +.. autoclass:: Search + :members: + +.. autoclass:: CqlCollection + :members: + +.. autoclass:: Geo + :members: diff --git a/docs/api/cassandra/datastax/graph/fluent/query.rst b/docs/api/cassandra/datastax/graph/fluent/query.rst new file mode 100644 index 0000000000..3dd859f96e --- /dev/null +++ b/docs/api/cassandra/datastax/graph/fluent/query.rst @@ -0,0 +1,8 @@ +:mod:`cassandra.datastax.graph.fluent.query` +============================================ + +.. module:: cassandra.datastax.graph.fluent.query + + +.. autoclass:: TraversalBatch + :members: diff --git a/docs/api/cassandra/datastax/graph/index.rst b/docs/api/cassandra/datastax/graph/index.rst new file mode 100644 index 0000000000..dafd5f65fd --- /dev/null +++ b/docs/api/cassandra/datastax/graph/index.rst @@ -0,0 +1,121 @@ +``cassandra.datastax.graph`` - Graph Statements, Options, and Row Factories +=========================================================================== + +.. 
_api-datastax-graph: + +.. module:: cassandra.datastax.graph + +.. autofunction:: single_object_row_factory + +.. autofunction:: graph_result_row_factory + +.. autofunction:: graph_object_row_factory + +.. autofunction:: graph_graphson2_row_factory + +.. autofunction:: graph_graphson3_row_factory + +.. function:: to_int(value) + + Wraps a value to be explicitly serialized as a graphson Int. + +.. function:: to_bigint(value) + + Wraps a value to be explicitly serialized as a graphson Bigint. + +.. function:: to_smallint(value) + + Wraps a value to be explicitly serialized as a graphson Smallint. + +.. function:: to_float(value) + + Wraps a value to be explicitly serialized as a graphson Float. + +.. function:: to_double(value) + + Wraps a value to be explicitly serialized as a graphson Double. + +.. autoclass:: GraphProtocol + :members: + +.. autoclass:: GraphOptions + + .. autoattribute:: graph_name + + .. autoattribute:: graph_source + + .. autoattribute:: graph_language + + .. autoattribute:: graph_read_consistency_level + + .. autoattribute:: graph_write_consistency_level + + .. autoattribute:: is_default_source + + .. autoattribute:: is_analytics_source + + .. autoattribute:: is_graph_source + + .. automethod:: set_source_default + + .. automethod:: set_source_analytics + + .. automethod:: set_source_graph + + +.. autoclass:: SimpleGraphStatement + :members: + +.. autoclass:: Result + :members: + +.. autoclass:: Vertex + :members: + +.. autoclass:: VertexProperty + :members: + +.. autoclass:: Edge + :members: + +.. autoclass:: Path + :members: + +.. autoclass:: T + :members: + +.. autoclass:: GraphSON1Serializer + :members: + +.. autoclass:: GraphSON1Deserializer + + .. automethod:: deserialize_date + + .. automethod:: deserialize_timestamp + + .. automethod:: deserialize_time + + .. automethod:: deserialize_duration + + .. automethod:: deserialize_int + + .. automethod:: deserialize_bigint + + .. automethod:: deserialize_double + + .. automethod:: deserialize_float + + .. automethod:: deserialize_uuid + + .. automethod:: deserialize_blob + + .. automethod:: deserialize_decimal + + .. automethod:: deserialize_point + + .. automethod:: deserialize_linestring + + .. automethod:: deserialize_polygon + +.. autoclass:: GraphSON2Reader + :members: diff --git a/docs/api/cassandra/graph.rst b/docs/api/cassandra/graph.rst new file mode 100644 index 0000000000..43ddd3086c --- /dev/null +++ b/docs/api/cassandra/graph.rst @@ -0,0 +1,121 @@ +``cassandra.graph`` - Graph Statements, Options, and Row Factories +================================================================== + +.. note:: This module is only for backward compatibility for dse-driver users. Consider using :ref:`cassandra.datastax.graph `. + +.. module:: cassandra.graph + +.. autofunction:: single_object_row_factory + +.. autofunction:: graph_result_row_factory + +.. autofunction:: graph_object_row_factory + +.. autofunction:: graph_graphson2_row_factory + +.. autofunction:: graph_graphson3_row_factory + +.. function:: to_int(value) + + Wraps a value to be explicitly serialized as a graphson Int. + +.. function:: to_bigint(value) + + Wraps a value to be explicitly serialized as a graphson Bigint. + +.. function:: to_smallint(value) + + Wraps a value to be explicitly serialized as a graphson Smallint. + +.. function:: to_float(value) + + Wraps a value to be explicitly serialized as a graphson Float. + +.. function:: to_double(value) + + Wraps a value to be explicitly serialized as a graphson Double. + +.. 
autoclass:: GraphProtocol + :members: + +.. autoclass:: GraphOptions + + .. autoattribute:: graph_name + + .. autoattribute:: graph_source + + .. autoattribute:: graph_language + + .. autoattribute:: graph_read_consistency_level + + .. autoattribute:: graph_write_consistency_level + + .. autoattribute:: is_default_source + + .. autoattribute:: is_analytics_source + + .. autoattribute:: is_graph_source + + .. automethod:: set_source_default + + .. automethod:: set_source_analytics + + .. automethod:: set_source_graph + + +.. autoclass:: SimpleGraphStatement + :members: + +.. autoclass:: Result + :members: + +.. autoclass:: Vertex + :members: + +.. autoclass:: VertexProperty + :members: + +.. autoclass:: Edge + :members: + +.. autoclass:: Path + :members: + +.. autoclass:: GraphSON1Serializer + :members: + +.. autoclass:: GraphSON1Deserializer + + .. automethod:: deserialize_date + + .. automethod:: deserialize_timestamp + + .. automethod:: deserialize_time + + .. automethod:: deserialize_duration + + .. automethod:: deserialize_int + + .. automethod:: deserialize_bigint + + .. automethod:: deserialize_double + + .. automethod:: deserialize_float + + .. automethod:: deserialize_uuid + + .. automethod:: deserialize_blob + + .. automethod:: deserialize_decimal + + .. automethod:: deserialize_point + + .. automethod:: deserialize_linestring + + .. automethod:: deserialize_polygon + +.. autoclass:: GraphSON2Reader + :members: + +.. autoclass:: GraphSON3Reader + :members: diff --git a/docs/api/cassandra/io/asyncioreactor.rst b/docs/api/cassandra/io/asyncioreactor.rst new file mode 100644 index 0000000000..38ae63ca7f --- /dev/null +++ b/docs/api/cassandra/io/asyncioreactor.rst @@ -0,0 +1,7 @@ +``cassandra.io.asyncioreactor`` - ``asyncio`` Event Loop +===================================================================== + +.. module:: cassandra.io.asyncioreactor + +.. autoclass:: AsyncioConnection + :members: diff --git a/docs/api/cassandra/metadata.rst b/docs/api/cassandra/metadata.rst index d797f739de..91fe39fd99 100644 --- a/docs/api/cassandra/metadata.rst +++ b/docs/api/cassandra/metadata.rst @@ -14,7 +14,7 @@ .. autoclass:: Metadata () :members: - :exclude-members: rebuild_schema, rebuild_token_map, add_host, remove_host, get_host + :exclude-members: rebuild_schema, rebuild_token_map, add_host, remove_host Schemas ------- @@ -34,6 +34,12 @@ Schemas .. autoclass:: TableMetadata () :members: +.. autoclass:: TableMetadataV3 () + :members: + +.. autoclass:: TableMetadataDSE68 () + :members: + .. autoclass:: ColumnMetadata () :members: @@ -43,6 +49,12 @@ Schemas .. autoclass:: MaterializedViewMetadata () :members: +.. autoclass:: VertexMetadata () + :members: + +.. autoclass:: EdgeMetadata () + :members: + Tokens and Ring Topology ------------------------ @@ -64,6 +76,10 @@ Tokens and Ring Topology .. autoclass:: ReplicationStrategy :members: +.. autoclass:: ReplicationFactor + :members: + :exclude-members: create + .. autoclass:: SimpleStrategy :members: @@ -72,3 +88,5 @@ Tokens and Ring Topology .. autoclass:: LocalStrategy :members: + +.. autofunction:: group_keys_by_replica diff --git a/docs/api/cassandra/policies.rst b/docs/api/cassandra/policies.rst index 44346c4bd4..387b19ed95 100644 --- a/docs/api/cassandra/policies.rst +++ b/docs/api/cassandra/policies.rst @@ -24,6 +24,32 @@ Load Balancing .. autoclass:: TokenAwarePolicy :members: +.. autoclass:: HostFilterPolicy + + .. we document these methods manually so we can specify a param to predicate + + .. automethod:: predicate(host) + .. 
automethod:: distance + .. automethod:: make_query_plan + +.. autoclass:: DefaultLoadBalancingPolicy + :members: + +.. autoclass:: DSELoadBalancingPolicy + :members: + +Translating Server Node Addresses +--------------------------------- + +.. autoclass:: AddressTranslator + :members: + +.. autoclass:: IdentityTranslator + :members: + +.. autoclass:: EC2MultiRegionTranslator + :members: + Marking Hosts Up or Down ------------------------ @@ -59,3 +85,12 @@ Retrying Failed Operations .. autoclass:: DowngradingConsistencyRetryPolicy :members: + +Retrying Idempotent Operations +------------------------------ + +.. autoclass:: SpeculativeExecutionPolicy + :members: + +.. autoclass:: ConstantSpeculativeExecutionPolicy + :members: diff --git a/docs/api/cassandra/protocol.rst b/docs/api/cassandra/protocol.rst index f3a3cd5ab5..f615ab1a70 100644 --- a/docs/api/cassandra/protocol.rst +++ b/docs/api/cassandra/protocol.rst @@ -45,11 +45,11 @@ and ``NumpyProtocolHandler``. They can be used as follows: These protocol handlers comprise different parsers, and return results as described below: - - ProtocolHandler: this default implementation is a drop-in replacement for the pure-Python version. - The rows are all parsed upfront, before results are returned. +- ProtocolHandler: this default implementation is a drop-in replacement for the pure-Python version. + The rows are all parsed upfront, before results are returned. - - LazyProtocolHandler: near drop-in replacement for the above, except that it returns an iterator over rows, - lazily decoded into the default row format (this is more efficient since all decoded results are not materialized at once) +- LazyProtocolHandler: near drop-in replacement for the above, except that it returns an iterator over rows, + lazily decoded into the default row format (this is more efficient since all decoded results are not materialized at once) - - NumpyProtocolHander: deserializes results directly into NumPy arrays. This facilitates efficient integration with - analysis toolkits such as Pandas. +- NumpyProtocolHander: deserializes results directly into NumPy arrays. This facilitates efficient integration with + analysis toolkits such as Pandas. diff --git a/docs/api/cassandra/query.rst b/docs/api/cassandra/query.rst index 55c56cf168..fcd79739b9 100644 --- a/docs/api/cassandra/query.rst +++ b/docs/api/cassandra/query.rst @@ -11,9 +11,6 @@ .. autofunction:: ordered_dict_factory -.. autoclass:: Statement - :members: - .. autoclass:: SimpleStatement :members: @@ -23,6 +20,9 @@ .. autoclass:: BoundStatement :members: +.. autoclass:: Statement () + :members: + .. autodata:: UNSET_VALUE :annotation: diff --git a/docs/api/cassandra/timestamps.rst b/docs/api/cassandra/timestamps.rst new file mode 100644 index 0000000000..7c7f534aea --- /dev/null +++ b/docs/api/cassandra/timestamps.rst @@ -0,0 +1,14 @@ +``cassandra.timestamps`` - Timestamp Generation +============================================= + +.. module:: cassandra.timestamps + +.. autoclass:: MonotonicTimestampGenerator (warn_on_drift=True, warning_threshold=0, warning_interval=0) + + .. autoattribute:: warn_on_drift + + .. autoattribute:: warning_threshold + + .. autoattribute:: warning_interval + + .. 
automethod:: _next_timestamp diff --git a/docs/api/index.rst b/docs/api/index.rst index 340a5e0235..9e778d508c 100644 --- a/docs/api/index.rst +++ b/docs/api/index.rst @@ -10,6 +10,7 @@ Core Driver cassandra/cluster cassandra/policies cassandra/auth + cassandra/graph cassandra/metadata cassandra/metrics cassandra/query @@ -20,6 +21,8 @@ Core Driver cassandra/concurrent cassandra/connection cassandra/util + cassandra/timestamps + cassandra/io/asyncioreactor cassandra/io/asyncorereactor cassandra/io/eventletreactor cassandra/io/libevreactor @@ -39,3 +42,13 @@ Object Mapper cassandra/cqlengine/connection cassandra/cqlengine/management cassandra/cqlengine/usertype + +DataStax Graph +-------------- +.. toctree:: + :maxdepth: 1 + + cassandra/datastax/graph/index + cassandra/datastax/graph/fluent/index + cassandra/datastax/graph/fluent/query + cassandra/datastax/graph/fluent/predicates diff --git a/docs/classic_graph.rst b/docs/classic_graph.rst new file mode 100644 index 0000000000..ef68c86359 --- /dev/null +++ b/docs/classic_graph.rst @@ -0,0 +1,299 @@ +DataStax Classic Graph Queries +============================== + +Getting Started +~~~~~~~~~~~~~~~ + +First, we need to create a graph in the system. To access the system API, we +use the system execution profile :: + + from cassandra.cluster import Cluster, EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT + + cluster = Cluster() + session = cluster.connect() + + graph_name = 'movies' + session.execute_graph("system.graph(name).ifNotExists().engine(Classic).create()", {'name': graph_name}, + execution_profile=EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT) + + +To execute requests on our newly created graph, we need to setup an execution +profile. Additionally, we also need to set the schema_mode to `development` +for the schema creation:: + + + from cassandra.cluster import Cluster, GraphExecutionProfile, EXEC_PROFILE_GRAPH_DEFAULT + from cassandra.graph import GraphOptions + + graph_name = 'movies' + ep = GraphExecutionProfile(graph_options=GraphOptions(graph_name=graph_name)) + + cluster = Cluster(execution_profiles={EXEC_PROFILE_GRAPH_DEFAULT: ep}) + session = cluster.connect() + + session.execute_graph("schema.config().option('graph.schema_mode').set('development')") + + +We are ready to configure our graph schema. We will create a simple one for movies:: + + # properties are used to define a vertex + properties = """ + schema.propertyKey("genreId").Text().create(); + schema.propertyKey("personId").Text().create(); + schema.propertyKey("movieId").Text().create(); + schema.propertyKey("name").Text().create(); + schema.propertyKey("title").Text().create(); + schema.propertyKey("year").Int().create(); + schema.propertyKey("country").Text().create(); + """ + + session.execute_graph(properties) # we can execute multiple statements in a single request + + # A Vertex represents a "thing" in the world. 
+ vertices = """ + schema.vertexLabel("genre").properties("genreId","name").create(); + schema.vertexLabel("person").properties("personId","name").create(); + schema.vertexLabel("movie").properties("movieId","title","year","country").create(); + """ + + session.execute_graph(vertices) + + # An edge represents a relationship between two vertices + edges = """ + schema.edgeLabel("belongsTo").single().connection("movie","genre").create(); + schema.edgeLabel("actor").connection("movie","person").create(); + """ + + session.execute_graph(edges) + + # Indexes to execute graph requests efficiently + indexes = """ + schema.vertexLabel("genre").index("genresById").materialized().by("genreId").add(); + schema.vertexLabel("genre").index("genresByName").materialized().by("name").add(); + schema.vertexLabel("person").index("personsById").materialized().by("personId").add(); + schema.vertexLabel("person").index("personsByName").materialized().by("name").add(); + schema.vertexLabel("movie").index("moviesById").materialized().by("movieId").add(); + schema.vertexLabel("movie").index("moviesByTitle").materialized().by("title").add(); + schema.vertexLabel("movie").index("moviesByYear").secondary().by("year").add(); + """ + +Next, we'll add some data:: + + session.execute_graph(""" + g.addV('genre').property('genreId', 1).property('name', 'Action').next(); + g.addV('genre').property('genreId', 2).property('name', 'Drama').next(); + g.addV('genre').property('genreId', 3).property('name', 'Comedy').next(); + g.addV('genre').property('genreId', 4).property('name', 'Horror').next(); + """) + + session.execute_graph(""" + g.addV('person').property('personId', 1).property('name', 'Mark Wahlberg').next(); + g.addV('person').property('personId', 2).property('name', 'Leonardo DiCaprio').next(); + g.addV('person').property('personId', 3).property('name', 'Iggy Pop').next(); + """) + + session.execute_graph(""" + g.addV('movie').property('movieId', 1).property('title', 'The Happening'). + property('year', 2008).property('country', 'United States').next(); + g.addV('movie').property('movieId', 2).property('title', 'The Italian Job'). + property('year', 2003).property('country', 'United States').next(); + + g.addV('movie').property('movieId', 3).property('title', 'Revolutionary Road'). + property('year', 2008).property('country', 'United States').next(); + g.addV('movie').property('movieId', 4).property('title', 'The Man in the Iron Mask'). + property('year', 1998).property('country', 'United States').next(); + + g.addV('movie').property('movieId', 5).property('title', 'Dead Man'). 
+ property('year', 1995).property('country', 'United States').next(); + """) + +Now that our genre, actor and movie vertices are added, we'll create the relationships (edges) between them:: + + session.execute_graph(""" + genre_horror = g.V().hasLabel('genre').has('name', 'Horror').next(); + genre_drama = g.V().hasLabel('genre').has('name', 'Drama').next(); + genre_action = g.V().hasLabel('genre').has('name', 'Action').next(); + + leo = g.V().hasLabel('person').has('name', 'Leonardo DiCaprio').next(); + mark = g.V().hasLabel('person').has('name', 'Mark Wahlberg').next(); + iggy = g.V().hasLabel('person').has('name', 'Iggy Pop').next(); + + the_happening = g.V().hasLabel('movie').has('title', 'The Happening').next(); + the_italian_job = g.V().hasLabel('movie').has('title', 'The Italian Job').next(); + rev_road = g.V().hasLabel('movie').has('title', 'Revolutionary Road').next(); + man_mask = g.V().hasLabel('movie').has('title', 'The Man in the Iron Mask').next(); + dead_man = g.V().hasLabel('movie').has('title', 'Dead Man').next(); + + the_happening.addEdge('belongsTo', genre_horror); + the_italian_job.addEdge('belongsTo', genre_action); + rev_road.addEdge('belongsTo', genre_drama); + man_mask.addEdge('belongsTo', genre_drama); + man_mask.addEdge('belongsTo', genre_action); + dead_man.addEdge('belongsTo', genre_drama); + + the_happening.addEdge('actor', mark); + the_italian_job.addEdge('actor', mark); + rev_road.addEdge('actor', leo); + man_mask.addEdge('actor', leo); + dead_man.addEdge('actor', iggy); + """) + +We are all set. You can now query your graph. Here are some examples:: + + # Find all movies of the genre Drama + for r in session.execute_graph(""" + g.V().has('genre', 'name', 'Drama').in('belongsTo').valueMap();"""): + print(r) + + # Find all movies of the same genre than the movie 'Dead Man' + for r in session.execute_graph(""" + g.V().has('movie', 'title', 'Dead Man').out('belongsTo').in('belongsTo').valueMap();"""): + print(r) + + # Find all movies of Mark Wahlberg + for r in session.execute_graph(""" + g.V().has('person', 'name', 'Mark Wahlberg').in('actor').valueMap();"""): + print(r) + +To see a more graph examples, see `DataStax Graph Examples `_. + +Graph Types +~~~~~~~~~~~ + +Here are the supported graph types with their python representations: + +========== ================ +DSE Graph Python +========== ================ +boolean bool +bigint long, int (PY3) +int int +smallint int +varint int +float float +double double +uuid uuid.UUID +Decimal Decimal +inet str +timestamp datetime.datetime +date datetime.date +time datetime.time +duration datetime.timedelta +point Point +linestring LineString +polygon Polygon +blob bytearray, buffer (PY2), memoryview (PY3), bytes (PY3) +========== ================ + +Graph Row Factory +~~~~~~~~~~~~~~~~~ + +By default (with :class:`.GraphExecutionProfile.row_factory` set to :func:`.graph.graph_object_row_factory`), known graph result +types are unpacked and returned as specialized types (:class:`.Vertex`, :class:`.Edge`). If the result is not one of these +types, a :class:`.graph.Result` is returned, containing the graph result parsed from JSON and removed from its outer dict. 
+The class has some accessor convenience methods for accessing top-level properties by name (`type`, `properties` above), +or lists by index:: + + # dicts with `__getattr__` or `__getitem__` + result = session.execute_graph("[[key_str: 'value', key_int: 3]]", execution_profile=EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT)[0] # Using system exec just because there is no graph defined + result # dse.graph.Result({u'key_str': u'value', u'key_int': 3}) + result.value # {u'key_int': 3, u'key_str': u'value'} (dict) + result.key_str # u'value' + result.key_int # 3 + result['key_str'] # u'value' + result['key_int'] # 3 + + # lists with `__getitem__` + result = session.execute_graph('[[0, 1, 2]]', execution_profile=EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT)[0] + result # dse.graph.Result([0, 1, 2]) + result.value # [0, 1, 2] (list) + result[1] # 1 (list[1]) + +You can use a different row factory by setting :attr:`.Session.default_graph_row_factory` or passing it to +:meth:`.Session.execute_graph`. For example, :func:`.graph.single_object_row_factory` returns the JSON result string`, +unparsed. :func:`.graph.graph_result_row_factory` returns parsed, but unmodified results (such that all metadata is retained, +unlike :func:`.graph.graph_object_row_factory`, which sheds some as attributes and properties are unpacked). These results +also provide convenience methods for converting to known types (:meth:`~.Result.as_vertex`, :meth:`~.Result.as_edge`, :meth:`~.Result.as_path`). + +Vertex and Edge properties are never unpacked since their types are unknown. If you know your graph schema and want to +deserialize properties, use the :class:`.GraphSON1Deserializer`. It provides convenient methods to deserialize by types (e.g. +deserialize_date, deserialize_uuid, deserialize_polygon etc.) Example:: + + # ... + from cassandra.graph import GraphSON1Deserializer + + row = session.execute_graph("g.V().toList()")[0] + value = row.properties['my_property_key'][0].value # accessing the VertexProperty value + value = GraphSON1Deserializer.deserialize_timestamp(value) + + print(value) # 2017-06-26 08:27:05 + print(type(value)) # + + +Named Parameters +~~~~~~~~~~~~~~~~ + +Named parameters are passed in a dict to :meth:`.cluster.Session.execute_graph`:: + + result_set = session.execute_graph('[a, b]', {'a': 1, 'b': 2}, execution_profile=EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT) + [r.value for r in result_set] # [1, 2] + +All python types listed in `Graph Types`_ can be passed as named parameters and will be serialized +automatically to their graph representation: + +Example:: + + session.execute_graph(""" + g.addV('person'). + property('name', text_value). + property('age', integer_value). + property('birthday', timestamp_value). + property('house_yard', polygon_value).toList() + """, { + 'text_value': 'Mike Smith', + 'integer_value': 34, + 'timestamp_value': datetime.datetime(1967, 12, 30), + 'polygon_value': Polygon(((30, 10), (40, 40), (20, 40), (10, 20), (30, 10))) + }) + + +As with all Execution Profile parameters, graph options can be set in the cluster default (as shown in the first example) +or specified per execution:: + + ep = session.execution_profile_clone_update(EXEC_PROFILE_GRAPH_DEFAULT, + graph_options=GraphOptions(graph_name='something-else')) + session.execute_graph(statement, execution_profile=ep) + +Using GraphSON2 Protocol +~~~~~~~~~~~~~~~~~~~~~~~~ + +The default graph protocol used is GraphSON1. 
However GraphSON1 may +cause problems of type conversion happening during the serialization +of the query to the DSE Graph server, or the deserialization of the +responses back from a string Gremlin query. GraphSON2 offers better +support for the complex data types handled by DSE Graph. + +DSE >=5.0.4 now offers the possibility to use the GraphSON2 protocol +for graph queries. Enabling GraphSON2 can be done by `changing the +graph protocol of the execution profile` and `setting the graphson2 row factory`:: + + from cassandra.cluster import Cluster, GraphExecutionProfile, EXEC_PROFILE_GRAPH_DEFAULT + from cassandra.graph import GraphOptions, GraphProtocol, graph_graphson2_row_factory + + # Create a GraphSON2 execution profile + ep = GraphExecutionProfile(graph_options=GraphOptions(graph_name='types', + graph_protocol=GraphProtocol.GRAPHSON_2_0), + row_factory=graph_graphson2_row_factory) + + cluster = Cluster(execution_profiles={EXEC_PROFILE_GRAPH_DEFAULT: ep}) + session = cluster.connect() + session.execute_graph(...) + +Using GraphSON2, all properties will be automatically deserialized to +its Python representation. Note that it may bring significant +behavioral change at runtime. + +It is generally recommended to switch to GraphSON2 as it brings more +consistent support for complex data types in the Graph driver and will +be activated by default in the next major version (Python dse-driver +driver 3.0). diff --git a/docs/cloud.rst b/docs/cloud.rst new file mode 100644 index 0000000000..3230720ec9 --- /dev/null +++ b/docs/cloud.rst @@ -0,0 +1,105 @@ +Cloud +----- +Connecting +========== +To connect to a DataStax Astra cluster: + +1. Download the secure connect bundle from your Astra account. +2. Connect to your cluster with + +.. code-block:: python + + from cassandra.cluster import Cluster + from cassandra.auth import PlainTextAuthProvider + + cloud_config = { + 'secure_connect_bundle': '/path/to/secure-connect-dbname.zip' + } + auth_provider = PlainTextAuthProvider(username='user', password='pass') + cluster = Cluster(cloud=cloud_config, auth_provider=auth_provider) + session = cluster.connect() + +Cloud Config Options +==================== + +use_default_tempdir ++++++++++++++++++++ + +The secure connect bundle needs to be extracted to load the certificates into the SSLContext. +By default, the zip location is used as the base dir for the extraction. In some environments, +the zip location file system is read-only (e.g Azure Function). With *use_default_tempdir* set to *True*, +the default temporary directory of the system will be used as base dir. + +.. code:: python + + cloud_config = { + 'secure_connect_bundle': '/path/to/secure-connect-dbname.zip', + 'use_default_tempdir': True + } + ... + +connect_timeout ++++++++++++++++++++ + +As part of the process of connecting to Astra the Python driver will query a service to retrieve +current information about your cluster. You can control the connection timeout for this operation +using *connect_timeout*. If you observe errors in `read_metadata_info` you might consider increasing +this parameter. This timeout is specified in seconds. + +.. code:: python + + cloud_config = { + 'secure_connect_bundle': '/path/to/secure-connect-dbname.zip', + 'connect_timeout': 120 + } + ... + +Astra Differences +================== +In most circumstances, the client code for interacting with an Astra cluster will be the same as interacting with any other Cassandra cluster. 
The exceptions being: + +* A cloud configuration must be passed to a :class:`~.Cluster` instance via the `cloud` attribute (as demonstrated above). +* An SSL connection will be established automatically. Manual SSL configuration is not allowed, and using `ssl_context` or `ssl_options` will result in an exception. +* A :class:`~.Cluster`'s `contact_points` attribute should not be used. The cloud config contains all of the necessary contact information. +* If a consistency level is not specified for an execution profile or query, then :attr:`.ConsistencyLevel.LOCAL_QUORUM` will be used as the default. + + +Limitations +=========== + +Event loops +^^^^^^^^^^^ +Evenlet isn't yet supported for python 3.7+ due to an `issue in Eventlet `_. + + +CqlEngine +========= + +When using the object mapper, you can configure cqlengine with :func:`~.cqlengine.connection.set_session`: + +.. code:: python + + from cassandra.cqlengine import connection + ... + + c = Cluster(cloud={'secure_connect_bundle':'/path/to/secure-connect-test.zip'}, + auth_provider=PlainTextAuthProvider('user', 'pass')) + s = c.connect('myastrakeyspace') + connection.set_session(s) + ... + +If you are using some third-party libraries (flask, django, etc.), you might not be able to change the +configuration mechanism. For this reason, the `hosts` argument of the default +:func:`~.cqlengine.connection.setup` function will be ignored if a `cloud` config is provided: + +.. code:: python + + from cassandra.cqlengine import connection + ... + + connection.setup( + None, # or anything else + "myastrakeyspace", cloud={ + 'secure_connect_bundle':'/path/to/secure-connect-test.zip' + }, + auth_provider=PlainTextAuthProvider('user', 'pass')) diff --git a/docs/column_encryption.rst b/docs/column_encryption.rst new file mode 100644 index 0000000000..ab67ef16d0 --- /dev/null +++ b/docs/column_encryption.rst @@ -0,0 +1,101 @@ +Column Encryption +================= + +Overview +-------- +Support for client-side encryption of data was added in version 3.27.0 of the Python driver. When using +this feature data will be encrypted on-the-fly according to a specified :class:`~.ColumnEncryptionPolicy` +instance. This policy is also used to decrypt data in returned rows. If a prepared statement is used +this decryption is transparent to the user; retrieved data will be decrypted and converted into the original +type (according to definitions in the encryption policy). Support for simple (i.e. non-prepared) queries is +also available, although in this case values must be manually encrypted and/or decrypted. The +:class:`~.ColumnEncryptionPolicy` instance provides methods to assist with these operations. + +Client-side encryption and decryption should work against all versions of Cassandra and DSE. It does not +utilize any server-side functionality to do its work. + +WARNING: Encryption format changes in 3.28.0 +------------------------------------------------ +Python driver 3.28.0 introduces a new encryption format for data written by :class:`~.AES256ColumnEncryptionPolicy`. +As a result, any encrypted data written by Python driver 3.27.0 will **NOT** be readable. +If you upgraded from 3.27.0, you should re-encrypt your data with 3.28.0. + +Configuration +------------- +Client-side encryption is enabled by creating an instance of a subclass of :class:`~.ColumnEncryptionPolicy` +and adding information about columns to be encrypted to it. This policy is then supplied to :class:`~.Cluster` +when it's created. + +.. 
code-block:: python + + import os + + from cassandra.policies import ColDesc + from cassandra.column_encryption.policies import AES256ColumnEncryptionPolicy, AES256_KEY_SIZE_BYTES + + key = os.urandom(AES256_KEY_SIZE_BYTES) + cl_policy = AES256ColumnEncryptionPolicy() + col_desc = ColDesc('ks1','table1','column1') + cql_type = "int" + cl_policy.add_column(col_desc, key, cql_type) + cluster = Cluster(column_encryption_policy=cl_policy) + +:class:`~.AES256ColumnEncryptionPolicy` is a subclass of :class:`~.ColumnEncryptionPolicy` which provides +encryption and decryption via AES-256. This class is currently the only available column encryption policy +implementation, although users can certainly implement their own by subclassing :class:`~.ColumnEncryptionPolicy`. + +:class:`~.ColDesc` is a named tuple which uniquely identifies a column in a given keyspace and table. When we +have this tuple, the encryption key and the CQL type contained by this column we can add the column to the policy +using :func:`~.ColumnEncryptionPolicy.add_column`. Once we have added all column definitions to the policy we +pass it along to the cluster. + +The CQL type for the column only has meaning at the client; it is never sent to Cassandra. The encryption key +is also never sent to the server; all the server ever sees are random bytes reflecting the encrypted data. As a +result all columns containing client-side encrypted values should be declared with the CQL type "blob" at the +Cassandra server. + +Usage +----- + +Encryption +^^^^^^^^^^ +Client-side encryption shines most when used with prepared statements. A prepared statement is aware of information +about the columns in the query it was built from and we can use this information to transparently encrypt any +supplied parameters. For example, we can create a prepared statement to insert a value into column1 (as defined above) +by executing the following code after creating a :class:`~.Cluster` in the manner described above: + +.. code-block:: python + + session = cluster.connect() + prepared = session.prepare("insert into ks1.table1 (column1) values (?)") + session.execute(prepared, (1000,)) + +Our encryption policy will detect that "column1" is an encrypted column and take appropriate action. + +As mentioned above client-side encryption can also be used with simple queries, although such use cases are +certainly not transparent. :class:`~.ColumnEncryptionPolicy` provides a helper named +:func:`~.ColumnEncryptionPolicy.encode_and_encrypt` which will convert an input value into bytes using the +standard serialization methods employed by the driver. The result is then encrypted according to the configuration +of the policy. Using this approach the example above could be implemented along the lines of the following: + +.. code-block:: python + + session = cluster.connect() + session.execute("insert into ks1.table1 (column1) values (%s)",(cl_policy.encode_and_encrypt(col_desc, 1000),)) + +Decryption +^^^^^^^^^^ +Decryption of values returned from the server is always transparent. Whether we're executing a simple or prepared +statement encrypted columns will be decrypted automatically and made available via rows just like any other +result. + +Limitations +----------- +:class:`~.AES256ColumnEncryptionPolicy` uses the implementation of AES-256 provided by the +`cryptography `_ module. Any limitations of this module should be considered +when deploying client-side encryption. 
Note specifically that a Rust compiler is required for modern versions +of the cryptography package, although wheels exist for many common platforms. + +Client-side encryption has been implemented for both the default Cython and pure Python row processing logic. +This functionality has not yet been ported to the NumPy Cython implementation. During testing, +the NumPy processing works on Python 3.7 but fails for Python 3.8. diff --git a/docs/conf.py b/docs/conf.py index 167c7bd89b..4c0dfb58d7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -43,7 +43,7 @@ # General information about the project. project = u'Cassandra Driver' -copyright = u'2013-2016 DataStax' +copyright = u'2013-2017 DataStax' # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -136,7 +136,14 @@ #html_use_smartypants = True # Custom sidebar templates, maps document names to template names. -#html_sidebars = {} +html_sidebars = { + '**': [ + 'about.html', + 'navigation.html', + 'relations.html', + 'searchbox.html' + ] +} # Additional templates that should be rendered to pages, maps page names to # template names. @@ -146,7 +153,7 @@ #html_domain_indices = True # If false, no index is generated. -#html_use_index = True +html_use_index = False # If true, the index is split into individual pages for each letter. #html_split_index = False @@ -216,5 +223,5 @@ # (source start file, name, description, authors, manual section). man_pages = [ ('index', 'cassandra-driver', u'Cassandra Driver Documentation', - [u'Tyler Hobbs'], 1) + [u'DataStax'], 1) ] diff --git a/docs/cqlengine/batches.rst b/docs/cqlengine/batches.rst index a567e31c27..306e7d01a6 100644 --- a/docs/cqlengine/batches.rst +++ b/docs/cqlengine/batches.rst @@ -8,101 +8,101 @@ cqlengine supports batch queries using the BatchQuery class. Batch queries can b Batch Query General Use Pattern =============================== - You can only create, update, and delete rows with a batch query, attempting to read rows out of the database with a batch query will fail. +You can only create, update, and delete rows with a batch query, attempting to read rows out of the database with a batch query will fail. - .. code-block:: python +.. 
code-block:: python - from cqlengine import BatchQuery + from cassandra.cqlengine.query import BatchQuery - #using a context manager - with BatchQuery() as b: - now = datetime.now() - em1 = ExampleModel.batch(b).create(example_type=0, description="1", created_at=now) - em2 = ExampleModel.batch(b).create(example_type=0, description="2", created_at=now) - em3 = ExampleModel.batch(b).create(example_type=0, description="3", created_at=now) - - # -- or -- - - #manually - b = BatchQuery() + #using a context manager + with BatchQuery() as b: now = datetime.now() em1 = ExampleModel.batch(b).create(example_type=0, description="1", created_at=now) em2 = ExampleModel.batch(b).create(example_type=0, description="2", created_at=now) em3 = ExampleModel.batch(b).create(example_type=0, description="3", created_at=now) - b.execute() - # updating in a batch + # -- or -- + + #manually + b = BatchQuery() + now = datetime.now() + em1 = ExampleModel.batch(b).create(example_type=0, description="1", created_at=now) + em2 = ExampleModel.batch(b).create(example_type=0, description="2", created_at=now) + em3 = ExampleModel.batch(b).create(example_type=0, description="3", created_at=now) + b.execute() + + # updating in a batch - b = BatchQuery() - em1.description = "new description" - em1.batch(b).save() - em2.description = "another new description" - em2.batch(b).save() - b.execute() + b = BatchQuery() + em1.description = "new description" + em1.batch(b).save() + em2.description = "another new description" + em2.batch(b).save() + b.execute() - # deleting in a batch - b = BatchQuery() - ExampleModel.objects(id=some_id).batch(b).delete() - ExampleModel.objects(id=some_id2).batch(b).delete() - b.execute() + # deleting in a batch + b = BatchQuery() + ExampleModel.objects(id=some_id).batch(b).delete() + ExampleModel.objects(id=some_id2).batch(b).delete() + b.execute() - Typically you will not want the block to execute if an exception occurs inside the `with` block. However, in the case that this is desirable, it's achievable by using the following syntax: +Typically you will not want the block to execute if an exception occurs inside the `with` block. However, in the case that this is desirable, it's achievable by using the following syntax: - .. code-block:: python +.. code-block:: python - with BatchQuery(execute_on_exception=True) as b: - LogEntry.batch(b).create(k=1, v=1) - mystery_function() # exception thrown in here - LogEntry.batch(b).create(k=1, v=2) # this code is never reached due to the exception, but anything leading up to here will execute in the batch. + with BatchQuery(execute_on_exception=True) as b: + LogEntry.batch(b).create(k=1, v=1) + mystery_function() # exception thrown in here + LogEntry.batch(b).create(k=1, v=2) # this code is never reached due to the exception, but anything leading up to here will execute in the batch. - If an exception is thrown somewhere in the block, any statements that have been added to the batch will still be executed. This is useful for some logging situations. +If an exception is thrown somewhere in the block, any statements that have been added to the batch will still be executed. This is useful for some logging situations. Batch Query Execution Callbacks =============================== - In order to allow secondary tasks to be chained to the end of batch, BatchQuery instances allow callbacks to be - registered with the batch, to be executed immediately after the batch executes. 
+In order to allow secondary tasks to be chained to the end of batch, BatchQuery instances allow callbacks to be +registered with the batch, to be executed immediately after the batch executes. - Multiple callbacks can be attached to same BatchQuery instance, they are executed in the same order that they - are added to the batch. +Multiple callbacks can be attached to same BatchQuery instance, they are executed in the same order that they +are added to the batch. - The callbacks attached to a given batch instance are executed only if the batch executes. If the batch is used as a - context manager and an exception is raised, the queued up callbacks will not be run. +The callbacks attached to a given batch instance are executed only if the batch executes. If the batch is used as a +context manager and an exception is raised, the queued up callbacks will not be run. - .. code-block:: python +.. code-block:: python - def my_callback(*args, **kwargs): - pass + def my_callback(*args, **kwargs): + pass - batch = BatchQuery() + batch = BatchQuery() - batch.add_callback(my_callback) - batch.add_callback(my_callback, 'positional arg', named_arg='named arg value') + batch.add_callback(my_callback) + batch.add_callback(my_callback, 'positional arg', named_arg='named arg value') - # if you need reference to the batch within the callback, - # just trap it in the arguments to be passed to the callback: - batch.add_callback(my_callback, cqlengine_batch=batch) + # if you need reference to the batch within the callback, + # just trap it in the arguments to be passed to the callback: + batch.add_callback(my_callback, cqlengine_batch=batch) - # once the batch executes... - batch.execute() + # once the batch executes... + batch.execute() - # the effect of the above scheduled callbacks will be similar to - my_callback() - my_callback('positional arg', named_arg='named arg value') - my_callback(cqlengine_batch=batch) + # the effect of the above scheduled callbacks will be similar to + my_callback() + my_callback('positional arg', named_arg='named arg value') + my_callback(cqlengine_batch=batch) - Failure in any of the callbacks does not affect the batch's execution, as the callbacks are started after the execution - of the batch is complete. +Failure in any of the callbacks does not affect the batch's execution, as the callbacks are started after the execution +of the batch is complete. Logged vs Unlogged Batches --------------------------- - By default, queries in cqlengine are LOGGED, which carries additional overhead from UNLOGGED. To explicitly state which batch type to use, simply: +By default, queries in cqlengine are LOGGED, which carries additional overhead from UNLOGGED. To explicitly state which batch type to use, simply: - .. code-block:: python +.. code-block:: python - from cqlengine.query import BatchType - with BatchQuery(batch_type=BatchType.Unlogged) as b: - LogEntry.batch(b).create(k=1, v=1) - LogEntry.batch(b).create(k=1, v=2) + from cassandra.cqlengine.query import BatchType + with BatchQuery(batch_type=BatchType.Unlogged) as b: + LogEntry.batch(b).create(k=1, v=1) + LogEntry.batch(b).create(k=1, v=2) diff --git a/docs/cqlengine/connections.rst b/docs/cqlengine/connections.rst new file mode 100644 index 0000000000..fd44303514 --- /dev/null +++ b/docs/cqlengine/connections.rst @@ -0,0 +1,137 @@ +=========== +Connections +=========== + +Connections aim to ease the use of multiple sessions with cqlengine. Connections can be set on a model class, per query or using a context manager. 
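As a quick orientation before the details below, a typical multi-connection setup registers a default connection plus one or more named connections; the host addresses, keyspace, and connection names in this sketch are placeholders:

.. code-block:: python

    from cassandra.cqlengine import connection

    # the first call creates the default cqlengine connection
    connection.setup(['127.0.0.1'], 'example_ks')

    # additional clusters are registered under a name and selected later
    # on a model (__connection__), per query (using()), or with ContextQuery
    connection.register_connection('cluster2', ['127.0.0.2'])

The sections that follow cover each of these steps in detail.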
+ + +Register a new connection +========================= + +To use cqlengine, you need at least a default connection. If you initialize cqlengine's connections with with :func:`connection.setup <.connection.setup>`, a connection will be created automatically. If you want to use another cluster/session, you need to register a new cqlengine connection. You register a connection with :func:`~.connection.register_connection`: + +.. code-block:: python + + from cassandra.cqlengine import connection + + connection.setup(['127.0.0.1') + connection.register_connection('cluster2', ['127.0.0.2']) + +:func:`~.connection.register_connection` can take a list of hosts, as shown above, in which case it will create a connection with a new session. It can also take a `session` argument if you've already created a session: + +.. code-block:: python + + from cassandra.cqlengine import connection + from cassandra.cluster import Cluster + + session = Cluster(['127.0.0.1']).connect() + connection.register_connection('cluster3', session=session) + + +Change the default connection +============================= + +You can change the default cqlengine connection on registration: + +.. code-block:: python + + from cassandra.cqlengine import connection + + connection.register_connection('cluster2', ['127.0.0.2'] default=True) + +or on the fly using :func:`~.connection.set_default_connection` + +.. code-block:: python + + connection.set_default_connection('cluster2') + +Unregister a connection +======================= + +You can unregister a connection using :func:`~.connection.unregister_connection`: + +.. code-block:: python + + connection.unregister_connection('cluster2') + +Management +========== + +When using multiples connections, you also need to sync your models on all connections (and keyspaces) that you need operate on. Management commands have been improved to ease this part. Here is an example: + +.. code-block:: python + + from cassandra.cqlengine import management + + keyspaces = ['ks1', 'ks2'] + conns = ['cluster1', 'cluster2'] + + # registers your connections + # ... + + # create all keyspaces on all connections + for ks in keyspaces: + management.create_simple_keyspace(ks, connections=conns) + + # define your Automobile model + # ... + + # sync your models + management.sync_table(Automobile, keyspaces=keyspaces, connections=conns) + + +Connection Selection +==================== + +cqlengine will select the default connection, unless your specify a connection using one of the following methods. + +Default Model Connection +------------------------ + +You can specify a default connection per model: + +.. code-block:: python + + class Automobile(Model): + __keyspace__ = 'test' + __connection__ = 'cluster2' + manufacturer = columns.Text(primary_key=True) + year = columns.Integer(primary_key=True) + model = columns.Text(primary_key=True) + + print(len(Automobile.objects.all())) # executed on the connection 'cluster2' + +QuerySet and model instance +--------------------------- + +You can use the :attr:`using() <.query.ModelQuerySet.using>` method to select a connection (or keyspace): + +.. code-block:: python + + Automobile.objects.using(connection='cluster1').create(manufacturer='honda', year=2010, model='civic') + q = Automobile.objects.filter(manufacturer='Tesla') + autos = q.using(keyspace='ks2', connection='cluster2').all() + + for auto in autos: + auto.using(connection='cluster1').save() + +Context Manager +--------------- + +You can use the ContextQuery as well to select a connection: + +.. 
code-block:: python + + with ContextQuery(Automobile, connection='cluster1') as A: + A.objects.filter(manufacturer='honda').all() # executed on 'cluster1' + + +BatchQuery +---------- + +With a BatchQuery, you can select the connection with the context manager. Note that all operations in the batch need to use the same connection. + +.. code-block:: python + + with BatchQuery(connection='cluster1') as b: + Automobile.objects.batch(b).create(manufacturer='honda', year=2010, model='civic') diff --git a/docs/cqlengine/faq.rst b/docs/cqlengine/faq.rst index dcaefae22c..6c056d02ea 100644 --- a/docs/cqlengine/faq.rst +++ b/docs/cqlengine/faq.rst @@ -14,9 +14,9 @@ Statement Ordering is not supported by CQL3 batches. Therefore, once cassandra needs resolving conflict(Updating the same column in one batch), The algorithm below would be used. - * If timestamps are different, pick the column with the largest timestamp (the value being a regular column or a tombstone) - * If timestamps are the same, and one of the columns in a tombstone ('null') - pick the tombstone - * If timestamps are the same, and none of the columns are tombstones, pick the column with the largest value +* If timestamps are different, pick the column with the largest timestamp (the value being a regular column or a tombstone) +* If timestamps are the same, and one of the columns in a tombstone ('null') - pick the tombstone +* If timestamps are the same, and none of the columns are tombstones, pick the column with the largest value Below is an example to show this scenario. @@ -48,3 +48,20 @@ resolve to the statement with the lastest timestamp. assert MyModel.objects(id=1).first().count == 3 assert MyModel.objects(id=1).first().text == '111' +How can I delete individual values from a row? +------------------------------------------------- + +When inserting with CQLEngine, ``None`` is equivalent to CQL ``NULL`` or to +issuing a ``DELETE`` on that column. For example: + +.. code-block:: python + + class MyModel(Model): + id = columns.Integer(primary_key=True) + text = columns.Text() + + m = MyModel.create(id=1, text='We can delete this with None') + assert MyModel.objects(id=1).first().text is not None + + m.update(text=None) + assert MyModel.objects(id=1).first().text is None diff --git a/docs/cqlengine/models.rst b/docs/cqlengine/models.rst index dffd06fb3f..719513f4a9 100644 --- a/docs/cqlengine/models.rst +++ b/docs/cqlengine/models.rst @@ -119,7 +119,7 @@ extend the model's validation method: if self.name == 'jon': raise ValidationError('no jon\'s allowed') -*Note*: while not required, the convention is to raise a ``ValidationError`` (``from cqlengine import ValidationError``) +*Note*: while not required, the convention is to raise a ``ValidationError`` (``from cassandra.cqlengine import ValidationError``) if validation fails. .. _model_inheritance: @@ -201,7 +201,7 @@ are only created, presisted, and queried via table Models. A short example to in users.create(name="Joe", addr=address(street="Easy St.", zipcode=99999)) user = users.objects(name="Joe")[0] - print user.name, user.addr + print(user.name, user.addr) # Joe address(street=u'Easy St.', zipcode=99999) UDTs are modeled by inheriting :class:`~.usertype.UserType`, and setting column type attributes. 
Types are then used in defining diff --git a/docs/cqlengine/queryset.rst b/docs/cqlengine/queryset.rst index 18287f924d..fa99585141 100644 --- a/docs/cqlengine/queryset.rst +++ b/docs/cqlengine/queryset.rst @@ -6,289 +6,311 @@ Making Queries Retrieving objects ================== - Once you've populated Cassandra with data, you'll probably want to retrieve some of it. This is accomplished with QuerySet objects. This section will describe how to use QuerySet objects to retrieve the data you're looking for. +Once you've populated Cassandra with data, you'll probably want to retrieve some of it. This is accomplished with QuerySet objects. This section will describe how to use QuerySet objects to retrieve the data you're looking for. Retrieving all objects ---------------------- - The simplest query you can make is to return all objects from a table. +The simplest query you can make is to return all objects from a table. - This is accomplished with the ``.all()`` method, which returns a QuerySet of all objects in a table +This is accomplished with the ``.all()`` method, which returns a QuerySet of all objects in a table - Using the Person example model, we would get all Person objects like this: +Using the Person example model, we would get all Person objects like this: - .. code-block:: python +.. code-block:: python - all_objects = Person.objects.all() + all_objects = Person.objects.all() .. _retrieving-objects-with-filters: Retrieving objects with filters ------------------------------- - Typically, you'll want to query only a subset of the records in your database. +Typically, you'll want to query only a subset of the records in your database. - That can be accomplished with the QuerySet's ``.filter(\*\*)`` method. +That can be accomplished with the QuerySet's ``.filter(\*\*)`` method. - For example, given the model definition: +For example, given the model definition: - .. code-block:: python +.. code-block:: python - class Automobile(Model): - manufacturer = columns.Text(primary_key=True) - year = columns.Integer(primary_key=True) - model = columns.Text() - price = columns.Decimal() - options = columns.Set(columns.Text) + class Automobile(Model): + manufacturer = columns.Text(primary_key=True) + year = columns.Integer(primary_key=True) + model = columns.Text() + price = columns.Decimal() + options = columns.Set(columns.Text) - ...and assuming the Automobile table contains a record of every car model manufactured in the last 20 years or so, we can retrieve only the cars made by a single manufacturer like this: +...and assuming the Automobile table contains a record of every car model manufactured in the last 20 years or so, we can retrieve only the cars made by a single manufacturer like this: - .. code-block:: python +.. code-block:: python - q = Automobile.objects.filter(manufacturer='Tesla') + q = Automobile.objects.filter(manufacturer='Tesla') - You can also use the more convenient syntax: +You can also use the more convenient syntax: - .. code-block:: python +.. code-block:: python - q = Automobile.objects(Automobile.manufacturer == 'Tesla') + q = Automobile.objects(Automobile.manufacturer == 'Tesla') - We can then further filter our query with another call to **.filter** +We can then further filter our query with another call to **.filter** - .. code-block:: python +.. 
code-block:: python - q = q.filter(year=2012) + q = q.filter(year=2012) - *Note: all queries involving any filtering MUST define either an '=' or an 'in' relation to either a primary key column, or an indexed column.* +*Note: all queries involving any filtering MUST define either an '=' or an 'in' relation to either a primary key column, or an indexed column.* Accessing objects in a QuerySet =============================== - There are several methods for getting objects out of a queryset +There are several methods for getting objects out of a queryset - * iterating over the queryset - .. code-block:: python +* iterating over the queryset + .. code-block:: python - for car in Automobile.objects.all(): - #...do something to the car instance - pass + for car in Automobile.objects.all(): + #...do something to the car instance + pass - * list index - .. code-block:: python +* list index + .. code-block:: python - q = Automobile.objects.all() - q[0] #returns the first result - q[1] #returns the second result + q = Automobile.objects.all() + q[0] #returns the first result + q[1] #returns the second result - .. note:: + .. note:: - * CQL does not support specifying a start position in it's queries. Therefore, accessing elements using array indexing will load every result up to the index value requested - * Using negative indices requires a "SELECT COUNT()" to be executed. This has a performance cost on large datasets. + * CQL does not support specifying a start position in it's queries. Therefore, accessing elements using array indexing will load every result up to the index value requested + * Using negative indices requires a "SELECT COUNT()" to be executed. This has a performance cost on large datasets. - * list slicing - .. code-block:: python +* list slicing + .. code-block:: python - q = Automobile.objects.all() - q[1:] #returns all results except the first - q[1:9] #returns a slice of the results + q = Automobile.objects.all() + q[1:] #returns all results except the first + q[1:9] #returns a slice of the results - .. note:: + .. note:: - * CQL does not support specifying a start position in it's queries. Therefore, accessing elements using array slicing will load every result up to the index value requested - * Using negative indices requires a "SELECT COUNT()" to be executed. This has a performance cost on large datasets. + * CQL does not support specifying a start position in it's queries. Therefore, accessing elements using array slicing will load every result up to the index value requested + * Using negative indices requires a "SELECT COUNT()" to be executed. This has a performance cost on large datasets. - * calling :attr:`get() ` on the queryset - .. code-block:: python +* calling :attr:`get() ` on the queryset + .. code-block:: python - q = Automobile.objects.filter(manufacturer='Tesla') - q = q.filter(year=2012) - car = q.get() + q = Automobile.objects.filter(manufacturer='Tesla') + q = q.filter(year=2012) + car = q.get() - this returns the object matching the queryset + this returns the object matching the queryset - * calling :attr:`first() ` on the queryset - .. code-block:: python +* calling :attr:`first() ` on the queryset + .. code-block:: python - q = Automobile.objects.filter(manufacturer='Tesla') - q = q.filter(year=2012) - car = q.first() + q = Automobile.objects.filter(manufacturer='Tesla') + q = q.filter(year=2012) + car = q.first() - this returns the first value in the queryset + this returns the first value in the queryset .. 
_query-filtering-operators: Filtering Operators =================== - :attr:`Equal To ` +:attr:`Equal To ` + +The default filtering operator. + +.. code-block:: python + + q = Automobile.objects.filter(manufacturer='Tesla') + q = q.filter(year=2012) #year == 2012 + +In addition to simple equal to queries, cqlengine also supports querying with other operators by appending a ``__`` to the field name on the filtering call + +:attr:`in (__in) ` - The default filtering operator. +.. code-block:: python - .. code-block:: python + q = Automobile.objects.filter(manufacturer='Tesla') + q = q.filter(year__in=[2011, 2012]) - q = Automobile.objects.filter(manufacturer='Tesla') - q = q.filter(year=2012) #year == 2012 - In addition to simple equal to queries, cqlengine also supports querying with other operators by appending a ``__`` to the field name on the filtering call +:attr:`> (__gt) ` - :attr:`in (__in) ` +.. code-block:: python - .. code-block:: python + q = Automobile.objects.filter(manufacturer='Tesla') + q = q.filter(year__gt=2010) # year > 2010 - q = Automobile.objects.filter(manufacturer='Tesla') - q = q.filter(year__in=[2011, 2012]) + # or the nicer syntax + q.filter(Automobile.year > 2010) - :attr:`> (__gt) ` +:attr:`>= (__gte) ` - .. code-block:: python +.. code-block:: python - q = Automobile.objects.filter(manufacturer='Tesla') - q = q.filter(year__gt=2010) # year > 2010 + q = Automobile.objects.filter(manufacturer='Tesla') + q = q.filter(year__gte=2010) # year >= 2010 - # or the nicer syntax + # or the nicer syntax - q.filter(Automobile.year > 2010) + q.filter(Automobile.year >= 2010) - :attr:`>= (__gte) ` +:attr:`< (__lt) ` - .. code-block:: python +.. code-block:: python - q = Automobile.objects.filter(manufacturer='Tesla') - q = q.filter(year__gte=2010) # year >= 2010 + q = Automobile.objects.filter(manufacturer='Tesla') + q = q.filter(year__lt=2012) # year < 2012 - # or the nicer syntax + # or... - q.filter(Automobile.year >= 2010) + q.filter(Automobile.year < 2012) - :attr:`< (__lt) ` +:attr:`<= (__lte) ` - .. code-block:: python +.. code-block:: python - q = Automobile.objects.filter(manufacturer='Tesla') - q = q.filter(year__lt=2012) # year < 2012 + q = Automobile.objects.filter(manufacturer='Tesla') + q = q.filter(year__lte=2012) # year <= 2012 - # or... + q.filter(Automobile.year <= 2012) - q.filter(Automobile.year < 2012) +:attr:`CONTAINS (__contains) ` - :attr:`<= (__lte) ` +The CONTAINS operator is available for all collection types (List, Set, Map). - .. code-block:: python +.. code-block:: python - q = Automobile.objects.filter(manufacturer='Tesla') - q = q.filter(year__lte=2012) # year <= 2012 + q = Automobile.objects.filter(manufacturer='Tesla') + q.filter(options__contains='backup camera').allow_filtering() - q.filter(Automobile.year <= 2012) +Note that we need to use allow_filtering() since the *options* column has no secondary index. - :attr:`CONTAINS (__contains) ` +:attr:`LIKE (__like) ` - The CONTAINS operator is available for all collection types (List, Set, Map). +The LIKE operator is available for text columns that have a SASI secondary index. - .. code-block:: python +.. code-block:: python - q = Automobile.objects.filter(manufacturer='Tesla') - q.filter(options__contains='backup camera').allow_filtering() + q = Automobile.objects.filter(model__like='%Civic%').allow_filtering() - Note that we need to use allow_filtering() since the *options* column has no secondary index. 
+:attr:`IS NOT NULL (IsNotNull(column_name)) ` + +The IS NOT NULL operator is not yet supported for C*. + +.. code-block:: python + + q = Automobile.objects.filter(IsNotNull('model')) + +Limitations: + +- Currently, cqlengine does not support SASI index creation. To use this feature, you need to create the SASI index using the core driver. +- Queries using LIKE must use allow_filtering() since the *model* column has no standard secondary index. Note that the server will use the SASI index properly when executing the query. TimeUUID Functions ================== - In addition to querying using regular values, there are two functions you can pass in when querying TimeUUID columns to help make filtering by them easier. Note that these functions don't actually return a value, but instruct the cql interpreter to use the functions in it's query. +In addition to querying using regular values, there are two functions you can pass in when querying TimeUUID columns to help make filtering by them easier. Note that these functions don't actually return a value, but instruct the cql interpreter to use the functions in it's query. - .. class:: MinTimeUUID(datetime) +.. class:: MinTimeUUID(datetime) - returns the minimum time uuid value possible for the given datetime + returns the minimum time uuid value possible for the given datetime - .. class:: MaxTimeUUID(datetime) +.. class:: MaxTimeUUID(datetime) - returns the maximum time uuid value possible for the given datetime + returns the maximum time uuid value possible for the given datetime - *Example* +*Example* - .. code-block:: python +.. code-block:: python - class DataStream(Model): - time = cqlengine.TimeUUID(primary_key=True) - data = cqlengine.Bytes() + class DataStream(Model): + id = columns.UUID(partition_key=True) + time = columns.TimeUUID(primary_key=True) + data = columns.Bytes() - min_time = datetime(1982, 1, 1) - max_time = datetime(1982, 3, 9) + min_time = datetime(1982, 1, 1) + max_time = datetime(1982, 3, 9) - DataStream.filter(time__gt=cqlengine.MinTimeUUID(min_time), time__lt=cqlengine.MaxTimeUUID(max_time)) + DataStream.filter(time__gt=functions.MinTimeUUID(min_time), time__lt=functions.MaxTimeUUID(max_time)) Token Function ============== - Token functon may be used only on special, virtual column pk__token, representing token of partition key (it also works for composite partition keys). - Cassandra orders returned items by value of partition key token, so using cqlengine.Token we can easy paginate through all table rows. +Token functon may be used only on special, virtual column pk__token, representing token of partition key (it also works for composite partition keys). +Cassandra orders returned items by value of partition key token, so using cqlengine.Token we can easy paginate through all table rows. - See http://cassandra.apache.org/doc/cql3/CQL.html#tokenFun +See http://cassandra.apache.org/doc/cql3/CQL-3.0.html#tokenFun - *Example* +*Example* - .. code-block:: python +.. 
code-block:: python - class Items(Model): - id = cqlengine.Text(primary_key=True) - data = cqlengine.Bytes() + class Items(Model): + id = columns.Text(primary_key=True) + data = columns.Bytes() - query = Items.objects.all().limit(10) + query = Items.objects.all().limit(10) - first_page = list(query); - last = first_page[-1] - next_page = list(query.filter(pk__token__gt=cqlengine.Token(last.pk))) + first_page = list(query); + last = first_page[-1] + next_page = list(query.filter(pk__token__gt=cqlengine.Token(last.pk))) QuerySets are immutable ======================= - When calling any method that changes a queryset, the method does not actually change the queryset object it's called on, but returns a new queryset object with the attributes of the original queryset, plus the attributes added in the method call. +When calling any method that changes a queryset, the method does not actually change the queryset object it's called on, but returns a new queryset object with the attributes of the original queryset, plus the attributes added in the method call. - *Example* +*Example* - .. code-block:: python +.. code-block:: python - #this produces 3 different querysets - #q does not change after it's initial definition - q = Automobiles.objects.filter(year=2012) - tesla2012 = q.filter(manufacturer='Tesla') - honda2012 = q.filter(manufacturer='Honda') + #this produces 3 different querysets + #q does not change after it's initial definition + q = Automobiles.objects.filter(year=2012) + tesla2012 = q.filter(manufacturer='Tesla') + honda2012 = q.filter(manufacturer='Honda') Ordering QuerySets ================== - Since Cassandra is essentially a distributed hash table on steroids, the order you get records back in will not be particularly predictable. +Since Cassandra is essentially a distributed hash table on steroids, the order you get records back in will not be particularly predictable. - However, you can set a column to order on with the ``.order_by(column_name)`` method. +However, you can set a column to order on with the ``.order_by(column_name)`` method. - *Example* +*Example* - .. code-block:: python +.. code-block:: python - #sort ascending - q = Automobiles.objects.all().order_by('year') - #sort descending - q = Automobiles.objects.all().order_by('-year') + #sort ascending + q = Automobiles.objects.all().order_by('year') + #sort descending + q = Automobiles.objects.all().order_by('-year') - *Note: Cassandra only supports ordering on a clustering key. In other words, to support ordering results, your model must have more than one primary key, and you must order on a primary key, excluding the first one.* +*Note: Cassandra only supports ordering on a clustering key. In other words, to support ordering results, your model must have more than one primary key, and you must order on a primary key, excluding the first one.* - *For instance, given our Automobile model, year is the only column we can order on.* +*For instance, given our Automobile model, year is the only column we can order on.* Values Lists ============ - There is a special QuerySet's method ``.values_list()`` - when called, QuerySet returns lists of values instead of model instances. It may significantly speedup things with lower memory footprint for large responses. - Each tuple contains the value from the respective field passed into the ``values_list()`` call — so the first item is the first field, etc. 
For example: +There is a special QuerySet's method ``.values_list()`` - when called, QuerySet returns lists of values instead of model instances. It may significantly speedup things with lower memory footprint for large responses. +Each tuple contains the value from the respective field passed into the ``values_list()`` call — so the first item is the first field, etc. For example: - .. code-block:: python +.. code-block:: python - items = list(range(20)) - random.shuffle(items) - for i in items: - TestModel.create(id=1, clustering_key=i) + items = list(range(20)) + random.shuffle(items) + for i in items: + TestModel.create(id=1, clustering_key=i) - values = list(TestModel.objects.values_list('clustering_key', flat=True)) - # [19L, 18L, 17L, 16L, 15L, 14L, 13L, 12L, 11L, 10L, 9L, 8L, 7L, 6L, 5L, 4L, 3L, 2L, 1L, 0L] + values = list(TestModel.objects.values_list('clustering_key', flat=True)) + # [19L, 18L, 17L, 16L, 15L, 14L, 13L, 12L, 11L, 10L, 9L, 8L, 7L, 6L, 5L, 4L, 3L, 2L, 1L, 0L] Per Query Timeouts =================== @@ -299,47 +321,83 @@ A timeout is specified in seconds and can be an int, float or None. None means no timeout. - .. code-block:: python +.. code-block:: python - class Row(Model): - id = columns.Integer(primary_key=True) - name = columns.Text() + class Row(Model): + id = columns.Integer(primary_key=True) + name = columns.Text() - Fetch all objects with a timeout of 5 seconds +Fetch all objects with a timeout of 5 seconds - .. code-block:: python +.. code-block:: python - Row.objects().timeout(5).all() + Row.objects().timeout(5).all() - Create a single row with a 50ms timeout +Create a single row with a 50ms timeout - .. code-block:: python +.. code-block:: python - Row(id=1, name='Jon').timeout(0.05).create() + Row(id=1, name='Jon').timeout(0.05).create() - Delete a single row with no timeout +Delete a single row with no timeout - .. code-block:: python +.. code-block:: python - Row(id=1).timeout(None).delete() + Row(id=1).timeout(None).delete() - Update a single row with no timeout +Update a single row with no timeout - .. code-block:: python +.. code-block:: python - Row(id=1).timeout(None).update(name='Blake') + Row(id=1).timeout(None).update(name='Blake') - Batch query timeouts +Batch query timeouts - .. code-block:: python +.. code-block:: python + + with BatchQuery(timeout=10) as b: + Row(id=1, name='Jon').create() + + +NOTE: You cannot set both timeout and batch at the same time, batch will use the timeout defined in it's constructor. +Setting the timeout on the model is meaningless and will raise an AssertionError. - with BatchQuery(timeout=10) as b: - Row(id=1, name='Jon').create() +.. _ttl-change: - NOTE: You cannot set both timeout and batch at the same time, batch will use the timeout defined in it's constructor. - Setting the timeout on the model is meaningless and will raise an AssertionError. +Default TTL and Per Query TTL +============================= + +Model default TTL now relies on the *default_time_to_live* feature, introduced in Cassandra 2.0. It is not handled anymore in the CQLEngine Model (cassandra-driver >=3.6). You can set the default TTL of a table like this: + +Example: + +.. code-block:: python + + class User(Model): + __options__ = {'default_time_to_live': 20} + + user_id = columns.UUID(primary_key=True) + ... + +You can set TTL per-query if needed. Here are a some examples: + +Example: + +.. code-block:: python + + class User(Model): + __options__ = {'default_time_to_live': 20} + + user_id = columns.UUID(primary_key=True) + ... 
+ + user = User.objects.create(user_id=1) # Default TTL 20 will be set automatically on the server + + user.ttl(30).update(age=21) # Update the TTL to 30 + User.objects.ttl(10).create(user_id=1) # TTL 10 + User(user_id=1, age=21).ttl(10).save() # TTL 10 Named Tables @@ -348,14 +406,14 @@ Named Tables Named tables are a way of querying a table without creating an class. They're useful for querying system tables or exploring an unfamiliar database. - .. code-block:: python +.. code-block:: python - from cqlengine.connection import setup - setup("127.0.0.1", "cqlengine_test") + from cassandra.cqlengine.connection import setup + setup("127.0.0.1", "cqlengine_test") - from cqlengine.named import NamedTable - user = NamedTable("cqlengine_test", "user") - user.objects() - user.objects()[0] + from cassandra.cqlengine.named import NamedTable + user = NamedTable("cqlengine_test", "user") + user.objects() + user.objects()[0] - # {u'pk': 1, u't': datetime.datetime(2014, 6, 26, 17, 10, 31, 774000)} + # {u'pk': 1, u't': datetime.datetime(2014, 6, 26, 17, 10, 31, 774000)} diff --git a/docs/cqlengine/third_party.rst b/docs/cqlengine/third_party.rst index c4c99dbf54..20c26df304 100644 --- a/docs/cqlengine/third_party.rst +++ b/docs/cqlengine/third_party.rst @@ -13,11 +13,11 @@ Here's how, in substance, CQLengine can be plugged to `Celery from celery import Celery from celery.signals import worker_process_init, beat_init - from cqlengine import connection - from cqlengine.connection import ( + from cassandra.cqlengine import connection + from cassandra.cqlengine.connection import ( cluster as cql_cluster, session as cql_session) - def cassandra_init(): + def cassandra_init(**kwargs): """ Initialize a clean Cassandra connection. """ if cql_cluster is not None: cql_cluster.shutdown() @@ -40,8 +40,8 @@ This is the code required for proper connection handling of CQLengine for a .. code-block:: python - from cqlengine import connection - from cqlengine.connection import ( + from cassandra.cqlengine import connection + from cassandra.cqlengine.connection import ( cluster as cql_cluster, session as cql_session) try: @@ -52,7 +52,7 @@ This is the code required for proper connection handling of CQLengine for a pass else: @postfork - def cassandra_init(): + def cassandra_init(**kwargs): """ Initialize a new Cassandra session in the context. Ensures that a new session is returned for every new request. diff --git a/docs/cqlengine/upgrade_guide.rst b/docs/cqlengine/upgrade_guide.rst index ee524cc7f8..5b0ab39360 100644 --- a/docs/cqlengine/upgrade_guide.rst +++ b/docs/cqlengine/upgrade_guide.rst @@ -40,7 +40,7 @@ Imports cqlengine is now integrated as a sub-package of the driver base package 'cassandra'. Upgrading will require adjusting imports to cqlengine. For example:: - from cqlengine import columns + from cassandra.cqlengine import columns is now:: diff --git a/docs/dates_and_times.rst b/docs/dates_and_times.rst new file mode 100644 index 0000000000..7a89f77437 --- /dev/null +++ b/docs/dates_and_times.rst @@ -0,0 +1,87 @@ +Working with Dates and Times +============================ + +This document is meant to provide on overview of the assumptions and limitations of the driver time handling, the +reasoning behind it, and describe approaches to working with these types. + +timestamps (Cassandra DateType) +------------------------------- + +Timestamps in Cassandra are timezone-naive timestamps encoded as millseconds since UNIX epoch. 
Clients working with
+timestamps in this database usually find it easiest to reason about them if they are always assumed to be UTC. To quote the
+pytz documentation, "The preferred way of dealing with times is to always work in UTC, converting to localtime only when
+generating output to be read by humans." The driver adheres to this tenet, and assumes UTC is always in the database. The
+driver attempts to make this correct on the way in, and assumes no timezone on the way out.
+
+Write Path
+~~~~~~~~~~
+When inserting timestamps, the driver handles serialization for the write path as follows:
+
+If the input is a ``datetime.datetime``, the serialization is normalized by starting with the ``utctimetuple()`` of the
+value.
+
+- If the ``datetime`` object is timezone-aware, the timestamp is shifted, and represents the UTC timestamp equivalent.
+- If the ``datetime`` object is timezone-naive, this results in no shift -- any ``datetime`` with no timezone information is assumed to be UTC.
+
+Note the second point above applies even to "local" times created using ``now()``::
+
+    >>> d = datetime.now()
+
+    >>> print(d.tzinfo)
+    None
+
+
+These do not contain timezone information intrinsically, so they will be assumed to be UTC and not shifted. When generating
+timestamps in the application, it is clearer to use ``datetime.utcnow()`` to be explicit about it.
+
+If the input for a timestamp is numeric, it is assumed to be an epoch-relative millisecond timestamp, as specified in the
+CQL spec -- no scaling or conversion is done.
+
+Read Path
+~~~~~~~~~
+The driver always assumes persisted timestamps are UTC and makes no attempt to localize them. Returned values are
+timezone-naive ``datetime.datetime``. We follow this approach because the datetime API has deficiencies around daylight
+saving time, and the de facto package for handling this is a third-party package (we try to minimize external dependencies
+and not make decisions for the integrator).
+
+The decision for how to handle timezones is left to the application. For the most part it is straightforward to apply
+localization to the ``datetime``\s returned by queries. One prevalent method is to use pytz for localization::
+
+    import pytz
+    user_tz = pytz.timezone('US/Central')
+    timestamp_naive = row.ts
+    timestamp_utc = pytz.utc.localize(timestamp_naive)
+    timestamp_presented = timestamp_utc.astimezone(user_tz)
+
+This is the most robust approach (likely refactored into a function). If it is deemed too cumbersome to apply for all call
+sites in the application, it is possible to patch the driver with custom deserialization for this type. However, doing
+this depends somewhat on internal APIs and what extensions are present, so we will only mention the possibility, and
+not spell it out here.
+
+date, time (Cassandra DateType)
+-------------------------------
+Date and time in Cassandra are idealized markers, much like ``datetime.date`` and ``datetime.time`` in the Python standard
+library. Unlike these Python implementations, the Cassandra encoding supports much wider ranges. To accommodate these
+ranges without overflow, this driver returns these data in custom types: :class:`.util.Date` and :class:`.util.Time`.
+
+Write Path
+~~~~~~~~~~
+For simple (not prepared) statements, the input values for each of these can be either a string literal or an encoded
+integer. See `Working with dates `_
+or `Working with time `_ for details
+on the encoding or string formats.
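+
+For example, a small sketch of the simple-statement case (the ``ks.events`` table and its ``k``, ``d``, and ``t``
+columns here are purely illustrative)::
+
+    # the string values are encoded as quoted CQL literals and parsed by the server
+    session.execute(
+        "INSERT INTO ks.events (k, d, t) VALUES (%s, %s, %s)",
+        (1, '2017-02-01', '14:30:01.123456789'))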
+ +For prepared statements, the driver accepts anything that can be used to construct the :class:`.util.Date` or +:class:`.util.Time` classes. See the linked API docs for details. + +Read Path +~~~~~~~~~ +The driver always returns custom types for ``date`` and ``time``. + +The driver returns :class:`.util.Date` for ``date`` in order to accommodate the wider range of values without overflow. +For applications working within the supported range of [``datetime.MINYEAR``, ``datetime.MAXYEAR``], these are easily +converted to standard ``datetime.date`` insances using :meth:`.Date.date`. + +The driver returns :class:`.util.Time` for ``time`` in order to retain nanosecond precision stored in the database. +For applications not concerned with this level of precision, these are easily converted to standard ``datetime.time`` +insances using :meth:`.Time.time`. diff --git a/docs/execution_profiles.rst b/docs/execution_profiles.rst new file mode 100644 index 0000000000..0965d77f3d --- /dev/null +++ b/docs/execution_profiles.rst @@ -0,0 +1,156 @@ +Execution Profiles +================== + +Execution profiles aim at making it easier to execute requests in different ways within +a single connected ``Session``. Execution profiles are being introduced to deal with the exploding number of +configuration options, especially as the database platform evolves more complex workloads. + +The legacy configuration remains intact, but legacy and Execution Profile APIs +cannot be used simultaneously on the same client ``Cluster``. Legacy configuration +will be removed in the next major release (4.0). + +An execution profile and its parameters should be unique across ``Cluster`` instances. +For example, an execution profile and its ``LoadBalancingPolicy`` should +not be applied to more than one ``Cluster`` instance. + +This document explains how Execution Profiles relate to existing settings, and shows how to use the new profiles for +request execution. + +Mapping Legacy Parameters to Profiles +------------------------------------- + +Execution profiles can inherit from :class:`.cluster.ExecutionProfile`, and currently provide the following options, +previously input from the noted attributes: + +- load_balancing_policy - :attr:`.Cluster.load_balancing_policy` +- request_timeout - :attr:`.Session.default_timeout`, optional :meth:`.Session.execute` parameter +- retry_policy - :attr:`.Cluster.default_retry_policy`, optional :attr:`.Statement.retry_policy` attribute +- consistency_level - :attr:`.Session.default_consistency_level`, optional :attr:`.Statement.consistency_level` attribute +- serial_consistency_level - :attr:`.Session.default_serial_consistency_level`, optional :attr:`.Statement.serial_consistency_level` attribute +- row_factory - :attr:`.Session.row_factory` attribute + +When using the new API, these parameters can be defined by instances of :class:`.cluster.ExecutionProfile`. + +Using Execution Profiles +------------------------ +Default +~~~~~~~ + +.. code:: python + + from cassandra.cluster import Cluster + cluster = Cluster() + session = cluster.connect() + local_query = 'SELECT rpc_address FROM system.local' + for _ in cluster.metadata.all_hosts(): + print(session.execute(local_query)[0]) + + +.. parsed-literal:: + + Row(rpc_address='127.0.0.2') + Row(rpc_address='127.0.0.1') + + +The default execution profile is built from Cluster parameters and default Session attributes. This profile matches existing default +parameters. + +Initializing cluster with profiles +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
code:: python + + from cassandra.cluster import ExecutionProfile + from cassandra.policies import WhiteListRoundRobinPolicy + + node1_profile = ExecutionProfile(load_balancing_policy=WhiteListRoundRobinPolicy(['127.0.0.1'])) + node2_profile = ExecutionProfile(load_balancing_policy=WhiteListRoundRobinPolicy(['127.0.0.2'])) + + profiles = {'node1': node1_profile, 'node2': node2_profile} + session = Cluster(execution_profiles=profiles).connect() + for _ in cluster.metadata.all_hosts(): + print(session.execute(local_query, execution_profile='node1')[0]) + + +.. parsed-literal:: + + Row(rpc_address='127.0.0.1') + Row(rpc_address='127.0.0.1') + + +.. code:: python + + for _ in cluster.metadata.all_hosts(): + print(session.execute(local_query, execution_profile='node2')[0]) + + +.. parsed-literal:: + + Row(rpc_address='127.0.0.2') + Row(rpc_address='127.0.0.2') + + +.. code:: python + + for _ in cluster.metadata.all_hosts(): + print(session.execute(local_query)[0]) + + +.. parsed-literal:: + + Row(rpc_address='127.0.0.2') + Row(rpc_address='127.0.0.1') + +Note that, even when custom profiles are injected, the default ``TokenAwarePolicy(DCAwareRoundRobinPolicy())`` is still +present. To override the default, specify a policy with the :data:`~.cluster.EXEC_PROFILE_DEFAULT` key. + +.. code:: python + + from cassandra.cluster import EXEC_PROFILE_DEFAULT + profile = ExecutionProfile(request_timeout=30) + cluster = Cluster(execution_profiles={EXEC_PROFILE_DEFAULT: profile}) + + +Adding named profiles +~~~~~~~~~~~~~~~~~~~~~ + +New profiles can be added constructing from scratch, or deriving from default: + +.. code:: python + + locked_execution = ExecutionProfile(load_balancing_policy=WhiteListRoundRobinPolicy(['127.0.0.1'])) + node1_profile = 'node1_whitelist' + cluster.add_execution_profile(node1_profile, locked_execution) + + for _ in cluster.metadata.all_hosts(): + print(session.execute(local_query, execution_profile=node1_profile)[0]) + + +.. parsed-literal:: + + Row(rpc_address='127.0.0.1') + Row(rpc_address='127.0.0.1') + +See :meth:`.Cluster.add_execution_profile` for details and optional parameters. + +Passing a profile instance without mapping +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +We also have the ability to pass profile instances to be used for execution, but not added to the mapping: + +.. code:: python + + from cassandra.query import tuple_factory + + tmp = session.execution_profile_clone_update('node1', request_timeout=100, row_factory=tuple_factory) + + print(session.execute(local_query, execution_profile=tmp)[0]) + print(session.execute(local_query, execution_profile='node1')[0]) + +.. parsed-literal:: + + ('127.0.0.1',) + Row(rpc_address='127.0.0.1') + +The new profile is a shallow copy, so the ``tmp`` profile shares a load balancing policy with one managed by the cluster. +If reference objects are to be updated in the clone, one would typically set those attributes to a new instance. 
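+
+As a sketch of that last point (the retry policy below is chosen only for illustration; any profile attribute could be
+replaced the same way):
+
+.. code:: python
+
+    from cassandra.policies import DowngradingConsistencyRetryPolicy
+
+    # give the clone its own policy instance rather than mutating the object
+    # shared with the cluster-managed 'node1' profile
+    tmp2 = session.execution_profile_clone_update(
+        'node1', retry_policy=DowngradingConsistencyRetryPolicy())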
diff --git a/docs/faq.rst b/docs/faq.rst index 56cb648a24..194d5520e8 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -44,7 +44,7 @@ Since tracing is done asynchronously to the request, this method polls until the >>> result = future.result() >>> trace = future.get_query_trace() >>> for e in trace.events: - >>> print e.source_elapsed, e.description + >>> print(e.source_elapsed, e.description) 0:00:00.000077 Parsing select * from system.local 0:00:00.000153 Preparing statement @@ -67,7 +67,7 @@ With prepared statements, the replicas are obtained by ``routing_key``, based on >>> bound = prepared.bind((1,)) >>> replicas = cluster.metadata.get_replicas(bound.keyspace, bound.routing_key) >>> for h in replicas: - >>> print h.address + >>> print(h.address) 127.0.0.1 127.0.0.2 diff --git a/docs/geo_types.rst b/docs/geo_types.rst new file mode 100644 index 0000000000..f8750d687c --- /dev/null +++ b/docs/geo_types.rst @@ -0,0 +1,39 @@ +DSE Geometry Types +================== +This section shows how to query and work with the geometric types provided by DSE. + +These types are enabled implicitly by creating the Session from :class:`cassandra.cluster.Cluster`. +This module implicitly registers these types for use in the driver. This extension provides +some simple representative types in :mod:`cassandra.util` for inserting and retrieving data:: + + from cassandra.cluster import Cluster + from cassandra.util import Point, LineString, Polygon + session = Cluster().connect() + + session.execute("INSERT INTO ks.geo (k, point, line, poly) VALUES (%s, %s, %s, %s)", + 0, Point(1, 2), LineString(((1, 2), (3, 4))), Polygon(((1, 2), (3, 4), (5, 6)))) + +Queries returning geometric types return the :mod:`dse.util` types. Note that these can easily be used to construct +types from third-party libraries using the common attributes:: + + from shapely.geometry import LineString + shapely_linestrings = [LineString(res.line.coords) for res in session.execute("SELECT line FROM ks.geo")] + +For prepared statements, shapely geometry types can be used interchangeably with the built-in types because their +defining attributes are the same:: + + from shapely.geometry import Point + prepared = session.prepare("UPDATE ks.geo SET point = ? WHERE k = ?") + session.execute(prepared, (0, Point(1.2, 3.4))) + +In order to use shapely types in a CQL-interpolated (non-prepared) query, one must update the encoder with those types, specifying +the same string encoder as set for the internal types:: + + from cassandra import util + from shapely.geometry import Point, LineString, Polygon + + encoder_func = session.encoder.mapping[util.Point] + for t in (Point, LineString, Polygon): + session.encoder.mapping[t] = encoder_func + + session.execute("UPDATE ks.geo SET point = %s where k = %s", (0, Point(1.2, 3.4))) diff --git a/docs/getting_started.rst b/docs/getting_started.rst index 2d9c7ea461..432e42ec4f 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -3,13 +3,40 @@ Getting Started First, make sure you have the driver properly :doc:`installed `. -Connecting to Cassandra +Connecting to a Cluster ----------------------- Before we can start executing any queries against a Cassandra cluster we need to setup an instance of :class:`~.Cluster`. As the name suggests, you will typically have one instance of :class:`~.Cluster` for each Cassandra cluster you want to interact with. +First, make sure you have the Cassandra driver properly :doc:`installed `. 
+ +Connecting to Astra ++++++++++++++++++++ + +If you are a DataStax `Astra `_ user, +here is how to connect to your cluster: + +1. Download the secure connect bundle from your Astra account. +2. Connect to your cluster with + +.. code-block:: python + + from cassandra.cluster import Cluster + from cassandra.auth import PlainTextAuthProvider + + cloud_config = { + 'secure_connect_bundle': '/path/to/secure-connect-dbname.zip' + } + auth_provider = PlainTextAuthProvider(username='user', password='pass') + cluster = Cluster(cloud=cloud_config, auth_provider=auth_provider) + session = cluster.connect() + +See `Astra `_ and :doc:`cloud` for more details. + +Connecting to Cassandra ++++++++++++++++++++++++ The simplest way to create a :class:`~.Cluster` is like this: .. code-block:: python @@ -40,15 +67,7 @@ behavior in some other way, this is the place to do it: .. code-block:: python from cassandra.cluster import Cluster - from cassandra.policies import DCAwareRoundRobinPolicy - - cluster = Cluster( - ['10.1.1.3', '10.1.1.4', '10.1.1.5'], - load_balancing_policy=DCAwareRoundRobinPolicy(local_dc='US_EAST'), - port=9042) - - -You can find a more complete list of options in the :class:`~.Cluster` documentation. + cluster = Cluster(['192.168.0.1', '192.168.0.2'], port=..., ssl_context=...) Instantiating a :class:`~.Cluster` does not actually connect us to any nodes. To establish connections and begin executing queries we need a @@ -59,6 +78,8 @@ To establish connections and begin executing queries we need a cluster = Cluster() session = cluster.connect() +Session Keyspace +---------------- The :meth:`~.Cluster.connect()` method takes an optional ``keyspace`` argument which sets the default keyspace for all queries made through that :class:`~.Session`: @@ -67,7 +88,6 @@ which sets the default keyspace for all queries made through that :class:`~.Sess cluster = Cluster() session = cluster.connect('mykeyspace') - You can always change a Session's keyspace using :meth:`~.Session.set_keyspace` or by executing a ``USE `` query: @@ -77,6 +97,41 @@ by executing a ``USE `` query: # or you can do this instead session.execute('USE users') +Execution Profiles +------------------ +Profiles are passed in by ``execution_profiles`` dict. + +In this case we can construct the base ``ExecutionProfile`` passing all attributes: + +.. code-block:: python + + from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT + from cassandra.policies import WhiteListRoundRobinPolicy, DowngradingConsistencyRetryPolicy + from cassandra.query import tuple_factory + + profile = ExecutionProfile( + load_balancing_policy=WhiteListRoundRobinPolicy(['127.0.0.1']), + retry_policy=DowngradingConsistencyRetryPolicy(), + consistency_level=ConsistencyLevel.LOCAL_QUORUM, + serial_consistency_level=ConsistencyLevel.LOCAL_SERIAL, + request_timeout=15, + row_factory=tuple_factory + ) + cluster = Cluster(execution_profiles={EXEC_PROFILE_DEFAULT: profile}) + session = cluster.connect() + + print(session.execute("SELECT release_version FROM system.local").one()) + +Users are free to setup additional profiles to be used by name: + +.. code-block:: python + + profile_long = ExecutionProfile(request_timeout=30) + cluster = Cluster(execution_profiles={'long': profile_long}) + session = cluster.connect() + session.execute(statement, execution_profile='long') + +Also, parameters passed to ``Session.execute`` or attached to ``Statement``\s are still honored as before. 
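+
+For instance, a short sketch of a per-call override (reusing the 'long' profile registered above):
+
+.. code-block:: python
+
+    from cassandra import ConsistencyLevel
+    from cassandra.query import SimpleStatement
+
+    # a consistency level attached to the statement and a timeout passed to
+    # execute() both take precedence over the values in the 'long' profile
+    statement = SimpleStatement("SELECT release_version FROM system.local",
+                                consistency_level=ConsistencyLevel.QUORUM)
+    session.execute(statement, timeout=60, execution_profile='long')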
Executing Queries ----------------- @@ -87,7 +142,7 @@ way to execute a query is to use :meth:`~.Session.execute()`: rows = session.execute('SELECT name, age, email FROM users') for user_row in rows: - print user_row.name, user_row.age, user_row.email + print(user_row.name, user_row.age, user_row.email) This will transparently pick a Cassandra node to execute the query against and handle any retries that are necessary if the operation fails. @@ -103,30 +158,62 @@ examples are equivalent: rows = session.execute('SELECT name, age, email FROM users') for row in rows: - print row.name, row.age, row.email + print(row.name, row.age, row.email) .. code-block:: python rows = session.execute('SELECT name, age, email FROM users') for (name, age, email) in rows: - print name, age, email + print(name, age, email) .. code-block:: python rows = session.execute('SELECT name, age, email FROM users') for row in rows: - print row[0], row[1], row[2] + print(row[0], row[1], row[2]) If you prefer another result format, such as a ``dict`` per row, you can change the :attr:`~.Session.row_factory` attribute. -For queries that will be run repeatedly, you should use -`Prepared statements <#prepared-statements>`_. +As mentioned in our `Drivers Best Practices Guide `_, +it is highly recommended to use `Prepared statements <#prepared-statement>`_ for your +frequently run queries. + +.. _prepared-statement: + +Prepared Statements +------------------- +Prepared statements are queries that are parsed by Cassandra and then saved +for later use. When the driver uses a prepared statement, it only needs to +send the values of parameters to bind. This lowers network traffic +and CPU utilization within Cassandra because Cassandra does not have to +re-parse the query each time. + +To prepare a query, use :meth:`.Session.prepare()`: + +.. code-block:: python + + user_lookup_stmt = session.prepare("SELECT * FROM users WHERE user_id=?") + + users = [] + for user_id in user_ids_to_query: + user = session.execute(user_lookup_stmt, [user_id]) + users.append(user) + +:meth:`~.Session.prepare()` returns a :class:`~.PreparedStatement` instance +which can be used in place of :class:`~.SimpleStatement` instances or literal +string queries. It is automatically prepared against all nodes, and the driver +handles re-preparing against new nodes and restarted nodes when necessary. + +Note that the placeholders for prepared statements are ``?`` characters. This +is different than for simple, non-prepared statements (although future versions +of the driver may use the same placeholders for both). Passing Parameters to CQL Queries ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -When executing non-prepared statements, the driver supports two forms of -parameter place-holders: positional and named. +Althought it is not recommended, you can also pass parameters to non-prepared +statements. The driver supports two forms of parameter place-holders: positional +and named. Positional parameters are used with a ``%s`` placeholder. 
For example, when you execute: @@ -179,7 +266,7 @@ Named place-holders use the ``%(name)s`` form: """ INSERT INTO users (name, credits, user_id, username) VALUES (%(name)s, %(credits)s, %(user_id)s, %(name)s) - """ + """, {'name': "John O'Reilly", 'credits': 42, 'user_id': uuid.uuid1()} ) @@ -271,7 +358,7 @@ For example: try: rows = future.result() user = rows[0] - print user.name, user.age + print(user.name, user.age) except ReadTimeout: log.exception("Query timed out:") @@ -288,7 +375,7 @@ This works well for executing many queries concurrently: # wait for them to complete and use the results for future in futures: rows = future.result() - print rows[0].name + print(rows[0].name) Alternatively, instead of calling :meth:`~.ResponseFuture.result()`, you can attach callback and errback functions through the @@ -328,7 +415,8 @@ replicas of the data you are interacting with need to respond for the query to be considered a success. By default, :attr:`.ConsistencyLevel.LOCAL_ONE` will be used for all queries. -You can specify a different default for the session on :attr:`.Session.default_consistency_level`. +You can specify a different default by setting the :attr:`.ExecutionProfile.consistency_level` +for the execution profile with key :data:`~.cluster.EXEC_PROFILE_DEFAULT`. To specify a different consistency level per request, wrap queries in a :class:`~.SimpleStatement`: @@ -342,34 +430,6 @@ in a :class:`~.SimpleStatement`: consistency_level=ConsistencyLevel.QUORUM) session.execute(query, ('John', 42)) -Prepared Statements -------------------- -Prepared statements are queries that are parsed by Cassandra and then saved -for later use. When the driver uses a prepared statement, it only needs to -send the values of parameters to bind. This lowers network traffic -and CPU utilization within Cassandra because Cassandra does not have to -re-parse the query each time. - -To prepare a query, use :meth:`.Session.prepare()`: - -.. code-block:: python - - user_lookup_stmt = session.prepare("SELECT * FROM users WHERE user_id=?") - - users = [] - for user_id in user_ids_to_query: - user = session.execute(user_lookup_stmt, [user_id]) - users.append(user) - -:meth:`~.Session.prepare()` returns a :class:`~.PreparedStatement` instance -which can be used in place of :class:`~.SimpleStatement` instances or literal -string queries. It is automatically prepared against all nodes, and the driver -handles re-preparing against new nodes and restarted nodes when necessary. - -Note that the placeholders for prepared statements are ``?`` characters. This -is different than for simple, non-prepared statements (although future versions -of the driver may use the same placeholders for both). - Setting a Consistency Level with Prepared Statements ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ To specify a consistency level for prepared statements, you have two options. @@ -400,3 +460,43 @@ level on that: user3_lookup = user_lookup_stmt.bind([user_id3]) user3_lookup.consistency_level = ConsistencyLevel.ALL user3 = session.execute(user3_lookup) + +Speculative Execution +^^^^^^^^^^^^^^^^^^^^^ + +Speculative execution is a way to minimize latency by preemptively executing several +instances of the same query against different nodes. For more details about this +technique, see `Speculative Execution with DataStax Drivers `_. 
+ +To enable speculative execution: + +* Configure a :class:`~.policies.SpeculativeExecutionPolicy` with the ExecutionProfile +* Mark your query as idempotent, which mean it can be applied multiple + times without changing the result of the initial application. + See `Query Idempotence `_ for more details. + + +Example: + +.. code-block:: python + + from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT + from cassandra.policies import ConstantSpeculativeExecutionPolicy + from cassandra.query import SimpleStatement + + # Configure the speculative execution policy + ep = ExecutionProfile( + speculative_execution_policy=ConstantSpeculativeExecutionPolicy(delay=.5, max_attempts=10) + ) + cluster = Cluster(..., execution_profiles={EXEC_PROFILE_DEFAULT: ep}) + session = cluster.connect() + + # Mark the query idempotent + query = SimpleStatement( + "UPDATE my_table SET list_col = [1] WHERE pk = 1", + is_idempotent=True + ) + + # Execute. A new query will be sent to the server every 0.5 second + # until we receive a response, for a max number attempts of 10. + session.execute(query) diff --git a/docs/graph.rst b/docs/graph.rst new file mode 100644 index 0000000000..47dc53d38d --- /dev/null +++ b/docs/graph.rst @@ -0,0 +1,434 @@ +DataStax Graph Queries +====================== + +The driver executes graph queries over the Cassandra native protocol. Use +:meth:`.Session.execute_graph` or :meth:`.Session.execute_graph_async` for +executing gremlin queries in DataStax Graph. + +The driver defines three Execution Profiles suitable for graph execution: + +* :data:`~.cluster.EXEC_PROFILE_GRAPH_DEFAULT` +* :data:`~.cluster.EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT` +* :data:`~.cluster.EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT` + +See :doc:`getting_started` and :doc:`execution_profiles` +for more detail on working with profiles. + +In DSE 6.8.0, the Core graph engine has been introduced and is now the default. It +provides a better unified multi-model, performance and scale. This guide +is for graphs that use the core engine. If you work with previous versions of +DSE or existing graphs, see :doc:`classic_graph`. + +Getting Started with Graph and the Core Engine +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +First, we need to create a graph in the system. To access the system API, we +use the system execution profile :: + + from cassandra.cluster import Cluster, EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT + + cluster = Cluster() + session = cluster.connect() + + graph_name = 'movies' + session.execute_graph("system.graph(name).create()", {'name': graph_name}, + execution_profile=EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT) + + +Graphs that use the core engine only support GraphSON3. Since they are Cassandra tables under +the hood, we can automatically configure the execution profile with the proper options +(row_factory and graph_protocol) when executing queries. You only need to make sure that +the `graph_name` is set and GraphSON3 will be automatically used:: + + from cassandra.cluster import Cluster, GraphExecutionProfile, EXEC_PROFILE_GRAPH_DEFAULT + + graph_name = 'movies' + ep = GraphExecutionProfile(graph_options=GraphOptions(graph_name=graph_name)) + cluster = Cluster(execution_profiles={EXEC_PROFILE_GRAPH_DEFAULT: ep}) + session = cluster.connect() + session.execute_graph("g.addV(...)") + + +Note that this graph engine detection is based on the metadata. You might experience +some query errors if the graph has been newly created and is not yet in the metadata. 
This +would result to a badly configured execution profile. If you really want to avoid that, +configure your execution profile explicitly:: + + from cassandra.cluster import Cluster, GraphExecutionProfile, EXEC_PROFILE_GRAPH_DEFAULT + from cassandra.graph import GraphOptions, GraphProtocol, graph_graphson3_row_factory + + graph_name = 'movies' + ep_graphson3 = GraphExecutionProfile( + row_factory=graph_graphson3_row_factory, + graph_options=GraphOptions( + graph_protocol=GraphProtocol.GRAPHSON_3_0, + graph_name=graph_name)) + + cluster = Cluster(execution_profiles={'core': ep_graphson3}) + session = cluster.connect() + session.execute_graph("g.addV(...)", execution_profile='core') + + +We are ready to configure our graph schema. We will create a simple one for movies:: + + # A Vertex represents a "thing" in the world. + # Create the genre vertex + query = """ + schema.vertexLabel('genre') + .partitionBy('genreId', Int) + .property('name', Text) + .create() + """ + session.execute_graph(query) + + # Create the person vertex + query = """ + schema.vertexLabel('person') + .partitionBy('personId', Int) + .property('name', Text) + .create() + """ + session.execute_graph(query) + + # Create the movie vertex + query = """ + schema.vertexLabel('movie') + .partitionBy('movieId', Int) + .property('title', Text) + .property('year', Int) + .property('country', Text) + .create() + """ + session.execute_graph(query) + + # An edge represents a relationship between two vertices + # Create our edges + queries = """ + schema.edgeLabel('belongsTo').from('movie').to('genre').create(); + schema.edgeLabel('actor').from('movie').to('person').create(); + """ + session.execute_graph(queries) + + # Indexes to execute graph requests efficiently + + # If you have a node with the search workload enabled (solr), use the following: + indexes = """ + schema.vertexLabel('genre').searchIndex() + .by("name") + .create(); + + schema.vertexLabel('person').searchIndex() + .by("name") + .create(); + + schema.vertexLabel('movie').searchIndex() + .by('title') + .by("year") + .create(); + """ + session.execute_graph(indexes) + + # Otherwise, use secondary indexes: + indexes = """ + schema.vertexLabel('genre') + .secondaryIndex('by_genre') + .by('name') + .create() + + schema.vertexLabel('person') + .secondaryIndex('by_name') + .by('name') + .create() + + schema.vertexLabel('movie') + .secondaryIndex('by_title') + .by('title') + .create() + """ + session.execute_graph(indexes) + +Add some edge indexes (materialized views):: + + indexes = """ + schema.edgeLabel('belongsTo') + .from('movie') + .to('genre') + .materializedView('movie__belongsTo__genre_by_in_genreId') + .ifNotExists() + .partitionBy(IN, 'genreId') + .clusterBy(OUT, 'movieId', Asc) + .create() + + schema.edgeLabel('actor') + .from('movie') + .to('person') + .materializedView('movie__actor__person_by_in_personId') + .ifNotExists() + .partitionBy(IN, 'personId') + .clusterBy(OUT, 'movieId', Asc) + .create() + """ + session.execute_graph(indexes) + +Next, we'll add some data:: + + session.execute_graph(""" + g.addV('genre').property('genreId', 1).property('name', 'Action').next(); + g.addV('genre').property('genreId', 2).property('name', 'Drama').next(); + g.addV('genre').property('genreId', 3).property('name', 'Comedy').next(); + g.addV('genre').property('genreId', 4).property('name', 'Horror').next(); + """) + + session.execute_graph(""" + g.addV('person').property('personId', 1).property('name', 'Mark Wahlberg').next(); + g.addV('person').property('personId', 
2).property('name', 'Leonardo DiCaprio').next(); + g.addV('person').property('personId', 3).property('name', 'Iggy Pop').next(); + """) + + session.execute_graph(""" + g.addV('movie').property('movieId', 1).property('title', 'The Happening'). + property('year', 2008).property('country', 'United States').next(); + g.addV('movie').property('movieId', 2).property('title', 'The Italian Job'). + property('year', 2003).property('country', 'United States').next(); + + g.addV('movie').property('movieId', 3).property('title', 'Revolutionary Road'). + property('year', 2008).property('country', 'United States').next(); + g.addV('movie').property('movieId', 4).property('title', 'The Man in the Iron Mask'). + property('year', 1998).property('country', 'United States').next(); + + g.addV('movie').property('movieId', 5).property('title', 'Dead Man'). + property('year', 1995).property('country', 'United States').next(); + """) + +Now that our genre, actor and movie vertices are added, we'll create the relationships (edges) between them:: + + session.execute_graph(""" + genre_horror = g.V().hasLabel('genre').has('name', 'Horror').id().next(); + genre_drama = g.V().hasLabel('genre').has('name', 'Drama').id().next(); + genre_action = g.V().hasLabel('genre').has('name', 'Action').id().next(); + + leo = g.V().hasLabel('person').has('name', 'Leonardo DiCaprio').id().next(); + mark = g.V().hasLabel('person').has('name', 'Mark Wahlberg').id().next(); + iggy = g.V().hasLabel('person').has('name', 'Iggy Pop').id().next(); + + the_happening = g.V().hasLabel('movie').has('title', 'The Happening').id().next(); + the_italian_job = g.V().hasLabel('movie').has('title', 'The Italian Job').id().next(); + rev_road = g.V().hasLabel('movie').has('title', 'Revolutionary Road').id().next(); + man_mask = g.V().hasLabel('movie').has('title', 'The Man in the Iron Mask').id().next(); + dead_man = g.V().hasLabel('movie').has('title', 'Dead Man').id().next(); + + g.addE('belongsTo').from(__.V(the_happening)).to(__.V(genre_horror)).next(); + g.addE('belongsTo').from(__.V(the_italian_job)).to(__.V(genre_action)).next(); + g.addE('belongsTo').from(__.V(rev_road)).to(__.V(genre_drama)).next(); + g.addE('belongsTo').from(__.V(man_mask)).to(__.V(genre_drama)).next(); + g.addE('belongsTo').from(__.V(man_mask)).to(__.V(genre_action)).next(); + g.addE('belongsTo').from(__.V(dead_man)).to(__.V(genre_drama)).next(); + + g.addE('actor').from(__.V(the_happening)).to(__.V(mark)).next(); + g.addE('actor').from(__.V(the_italian_job)).to(__.V(mark)).next(); + g.addE('actor').from(__.V(rev_road)).to(__.V(leo)).next(); + g.addE('actor').from(__.V(man_mask)).to(__.V(leo)).next(); + g.addE('actor').from(__.V(dead_man)).to(__.V(iggy)).next(); + """) + +We are all set. You can now query your graph. Here are some examples:: + + # Find all movies of the genre Drama + for r in session.execute_graph(""" + g.V().has('genre', 'name', 'Drama').in('belongsTo').valueMap();"""): + print(r) + + # Find all movies of the same genre than the movie 'Dead Man' + for r in session.execute_graph(""" + g.V().has('movie', 'title', 'Dead Man').out('belongsTo').in('belongsTo').valueMap();"""): + print(r) + + # Find all movies of Mark Wahlberg + for r in session.execute_graph(""" + g.V().has('person', 'name', 'Mark Wahlberg').in('actor').valueMap();"""): + print(r) + +To see a more graph examples, see `DataStax Graph Examples `_. 
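+
+The same queries can also be run without blocking by using :meth:`.Session.execute_graph_async`, which returns a
+future; a small sketch::
+
+    # result() blocks until the server responds and returns the result set
+    future = session.execute_graph_async(
+        "g.V().has('person', 'name', 'Mark Wahlberg').in('actor').valueMap();")
+    for r in future.result():
+        print(r)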
+ +Graph Types for the Core Engine +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Here are the supported graph types with their python representations: + +============ ================= +DSE Graph Python Driver +============ ================= +text str +boolean bool +bigint long +int int +smallint int +varint long +double float +float float +uuid UUID +bigdecimal Decimal +duration Duration (cassandra.util) +inet str or IPV4Address/IPV6Address (if available) +timestamp datetime.datetime +date datetime.date +time datetime.time +polygon Polygon +point Point +linestring LineString +blob bytearray, buffer (PY2), memoryview (PY3), bytes (PY3) +list list +map dict +set set or list + (Can return a list due to numerical values returned by Java) +tuple tuple +udt class or namedtuple +============ ================= + +Named Parameters +~~~~~~~~~~~~~~~~ + +Named parameters are passed in a dict to :meth:`.cluster.Session.execute_graph`:: + + result_set = session.execute_graph('[a, b]', {'a': 1, 'b': 2}, execution_profile=EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT) + [r.value for r in result_set] # [1, 2] + +All python types listed in `Graph Types for the Core Engine`_ can be passed as named parameters and will be serialized +automatically to their graph representation: + +Example:: + + session.execute_graph(""" + g.addV('person'). + property('name', text_value). + property('age', integer_value). + property('birthday', timestamp_value). + property('house_yard', polygon_value).next() + """, { + 'text_value': 'Mike Smith', + 'integer_value': 34, + 'timestamp_value': datetime.datetime(1967, 12, 30), + 'polygon_value': Polygon(((30, 10), (40, 40), (20, 40), (10, 20), (30, 10))) + }) + + +As with all Execution Profile parameters, graph options can be set in the cluster default (as shown in the first example) +or specified per execution:: + + ep = session.execution_profile_clone_update(EXEC_PROFILE_GRAPH_DEFAULT, + graph_options=GraphOptions(graph_name='something-else')) + session.execute_graph(statement, execution_profile=ep) + +CQL collections, Tuple and UDT +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +This is a very interesting feature of the core engine: we can use all CQL data types, including +list, map, set, tuple and udt. 
Here is an example using all these types:: + + query = """ + schema.type('address') + .property('address', Text) + .property('city', Text) + .property('state', Text) + .create(); + """ + session.execute_graph(query) + + # It works the same way than normal CQL UDT, so we + # can create an udt class and register it + class Address(object): + def __init__(self, address, city, state): + self.address = address + self.city = city + self.state = state + + session.cluster.register_user_type(graph_name, 'address', Address) + + query = """ + schema.vertexLabel('person') + .partitionBy('personId', Int) + .property('address', typeOf('address')) + .property('friends', listOf(Text)) + .property('skills', setOf(Text)) + .property('scores', mapOf(Text, Int)) + .property('last_workout', tupleOf(Text, Date)) + .create() + """ + session.execute_graph(query) + + # insertion example + query = """ + g.addV('person') + .property('personId', pid) + .property('address', address) + .property('friends', friends) + .property('skills', skills) + .property('scores', scores) + .property('last_workout', last_workout) + .next() + """ + + session.execute_graph(query, { + 'pid': 3, + 'address': Address('42 Smith St', 'Quebec', 'QC'), + 'friends': ['Al', 'Mike', 'Cathy'], + 'skills': {'food', 'fight', 'chess'}, + 'scores': {'math': 98, 'french': 3}, + 'last_workout': ('CrossFit', datetime.date(2018, 11, 20)) + }) + +Limitations +----------- + +Since Python is not a strongly-typed language and the UDT/Tuple graphson representation is, you might +get schema errors when trying to write numerical data. Example:: + + session.execute_graph(""" + schema.vertexLabel('test_tuple').partitionBy('id', Int).property('t', tupleOf(Text, Bigint)).create() + """) + + session.execute_graph(""" + g.addV('test_tuple').property('id', 0).property('t', t) + """, + {'t': ('Test', 99))} + ) + + # error: [Invalid query] message="Value component 1 is of type int, not bigint" + +This is because the server requires the client to include a GraphSON schema definition +with every UDT or tuple query. In the general case, the driver can't determine what Graph type +is meant by, e.g., an int value, and so it can't serialize the value with the correct type in the schema. +The driver provides some numerical type-wrapper factories that you can use to specify types: + +* :func:`~.to_int` +* :func:`~.to_bigint` +* :func:`~.to_smallint` +* :func:`~.to_float` +* :func:`~.to_double` + +Here's the working example of the case above:: + + from cassandra.graph import to_bigint + + session.execute_graph(""" + g.addV('test_tuple').property('id', 0).property('t', t) + """, + {'t': ('Test', to_bigint(99))} + ) + +Continuous Paging +~~~~~~~~~~~~~~~~~ + +This is another nice feature that comes with the core engine: continuous paging with +graph queries. If all nodes of the cluster are >= DSE 6.8.0, it is automatically +enabled under the hood to get the best performance. 
If you want to explicitly +enable/disable it, you can do it through the execution profile:: + + # Disable it + ep = GraphExecutionProfile(..., continuous_paging_options=None)) + cluster = Cluster(execution_profiles={EXEC_PROFILE_GRAPH_DEFAULT: ep}) + + # Enable with a custom max_pages option + ep = GraphExecutionProfile(..., + continuous_paging_options=ContinuousPagingOptions(max_pages=10))) + cluster = Cluster(execution_profiles={EXEC_PROFILE_GRAPH_DEFAULT: ep}) diff --git a/docs/graph_fluent.rst b/docs/graph_fluent.rst new file mode 100644 index 0000000000..8d5ad5377d --- /dev/null +++ b/docs/graph_fluent.rst @@ -0,0 +1,415 @@ +DataStax Graph Fluent API +========================= + +The fluent API adds graph features to the core driver: + +* A TinkerPop GraphTraversalSource builder to execute traversals on a DSE cluster +* The ability to execution traversal queries explicitly using execute_graph +* GraphSON serializers for all DSE Graph types. +* DSE Search predicates + +The Graph fluent API depends on Apache TinkerPop and is not installed by default. Make sure +you have the Graph requirements are properly :ref:`installed `. + +You might be interested in reading the :doc:`DataStax Graph Getting Started documentation ` to +understand the basics of creating a graph and its schema. + +Graph Traversal Queries +~~~~~~~~~~~~~~~~~~~~~~~ + +The driver provides :meth:`.Session.execute_graph`, which allows users to execute traversal +query strings. Here is a simple example:: + + session.execute_graph("g.addV('genre').property('genreId', 1).property('name', 'Action').next();") + +Since graph queries can be very complex, working with strings is not very convenient and is +hard to maintain. This fluent API allows you to build Gremlin traversals and write your graph +queries directly in Python. These native traversal queries can be executed explicitly, with +a `Session` object, or implicitly:: + + from cassandra.cluster import Cluster, EXEC_PROFILE_GRAPH_DEFAULT + from cassandra.datastax.graph import GraphProtocol + from cassandra.datastax.graph.fluent import DseGraph + + # Create an execution profile, using GraphSON3 for Core graphs + ep_graphson3 = DseGraph.create_execution_profile( + 'my_core_graph_name', + graph_protocol=GraphProtocol.GRAPHSON_3_0) + cluster = Cluster(execution_profiles={EXEC_PROFILE_GRAPH_DEFAULT: ep_graphson3}) + session = cluster.connect() + + # Execute a fluent graph query + g = DseGraph.traversal_source(session=session) + g.addV('genre').property('genreId', 1).property('name', 'Action').next() + + # implicit execution caused by iterating over results + for v in g.V().has('genre', 'name', 'Drama').in_('belongsTo').valueMap(): + print(v) + +These :ref:`Python types ` are also supported transparently:: + + g.addV('person').property('name', 'Mike').property('birthday', datetime(1984, 3, 11)). \ + property('house_yard', Polygon(((30, 10), (40, 40), (20, 40), (10, 20), (30, 10))) + +More readings about Gremlin: + +* `DataStax Drivers Fluent API `_ +* `gremlin-python documentation `_ + +Configuring a Traversal Execution Profile +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The fluent api takes advantage of *configuration profiles* to allow +different execution configurations for the various query handlers. Graph traversal +execution requires a custom execution profile to enable Gremlin-bytecode as +query language. With Core graphs, it is important to use GraphSON3. Here is how +to accomplish this configuration: + +.. 
code-block:: python + + from cassandra.cluster import Cluster, EXEC_PROFILE_GRAPH_DEFAULT + from cassandra.datastax.graph import GraphProtocol + from cassandra.datastax.graph.fluent import DseGraph + + # Using GraphSON3 as graph protocol is a requirement with Core graphs. + ep = DseGraph.create_execution_profile( + 'graph_name', + graph_protocol=GraphProtocol.GRAPHSON_3_0) + + # For Classic graphs, GraphSON1, GraphSON2 and GraphSON3 (DSE 6.8+) are supported. + ep_classic = DseGraph.create_execution_profile('classic_graph_name') # default is GraphSON2 + + cluster = Cluster(execution_profiles={EXEC_PROFILE_GRAPH_DEFAULT: ep, 'classic': ep_classic}) + session = cluster.connect() + + g = DseGraph.traversal_source(session) # Build the GraphTraversalSource + print(g.V().toList()) # Traverse the Graph + +Note that the execution profile created with :meth:`DseGraph.create_execution_profile <.datastax.graph.fluent.DseGraph.create_execution_profile>` cannot +be used for any groovy string queries. + +If you want to change execution property defaults, please see the :doc:`Execution Profile documentation ` +for a more generalized discussion of the API. Graph traversal queries use the same execution profile defined for DSE graph. If you +need to change the default properties, please refer to the :doc:`DSE Graph query documentation page ` + +Explicit Graph Traversal Execution with a DSE Session +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Traversal queries can be executed explicitly using `session.execute_graph` or `session.execute_graph_async`. These functions +return results as DSE graph types. If you are familiar with DSE queries or need async execution, you might prefer that way. +Below is an example of explicit execution. For this example, assume the schema has been generated as above: + +.. code-block:: python + + from cassandra.cluster import Cluster, EXEC_PROFILE_GRAPH_DEFAULT + from cassandra.datastax.graph import GraphProtocol + from cassandra.datastax.graph.fluent import DseGraph + from pprint import pprint + + ep = DseGraph.create_execution_profile( + 'graph_name', + graph_protocol=GraphProtocol.GRAPHSON_3_0) + cluster = Cluster(execution_profiles={EXEC_PROFILE_GRAPH_DEFAULT: ep}) + session = cluster.connect() + + g = DseGraph.traversal_source(session=session) + +Convert a traversal to a bytecode query for classic graphs:: + + addV_query = DseGraph.query_from_traversal( + g.addV('genre').property('genreId', 1).property('name', 'Action'), + graph_protocol=GraphProtocol.GRAPHSON_3_0 + ) + v_query = DseGraph.query_from_traversal( + g.V(), + graph_protocol=GraphProtocol.GRAPHSON_3_0) + + for result in session.execute_graph(addV_query): + pprint(result.value) + for result in session.execute_graph(v_query): + pprint(result.value) + +Converting a traversal to a bytecode query for core graphs require some more work, because we +need the cluster context for UDT and tuple types: + +.. 
code-block:: python + context = { + 'cluster': cluster, + 'graph_name': 'the_graph_for_the_query' + } + addV_query = DseGraph.query_from_traversal( + g.addV('genre').property('genreId', 1).property('name', 'Action'), + graph_protocol=GraphProtocol.GRAPHSON_3_0, + context=context + ) + + for result in session.execute_graph(addV_query): + pprint(result.value) + +Implicit Graph Traversal Execution with TinkerPop +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Using the :class:`DseGraph <.datastax.graph.fluent.DseGraph>` class, you can build a GraphTraversalSource +that will execute queries on a DSE session without explicitly passing anything to +that session. We call this *implicit execution* because the `Session` is not +explicitly involved. Everything is managed internally by TinkerPop while +traversing the graph and the results are TinkerPop types as well. + +Synchronous Example +------------------- + +.. code-block:: python + + # Build the GraphTraversalSource + g = DseGraph.traversal_source(session) + # implicitly execute the query by traversing the TraversalSource + g.addV('genre').property('genreId', 1).property('name', 'Action').next() + + # blocks until the query is completed and return the results + results = g.V().toList() + pprint(results) + +Asynchronous Exemple +-------------------- + +You can execute a graph traversal query asynchronously by using `.promise()`. It returns a +python `Future `_. + +.. code-block:: python + + # Build the GraphTraversalSource + g = DseGraph.traversal_source(session) + # implicitly execute the query by traversing the TraversalSource + g.addV('genre').property('genreId', 1).property('name', 'Action').next() # not async + + # get a future and wait + future = g.V().promise() + results = list(future.result()) + pprint(results) + + # or set a callback + def cb(f): + results = list(f.result()) + pprint(results) + future = g.V().promise() + future.add_done_callback(cb) + # do other stuff... + +Specify the Execution Profile explicitly +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If you don't want to change the default graph execution profile (`EXEC_PROFILE_GRAPH_DEFAULT`), you can register a new +one as usual and use it explicitly. Here is an example: + +.. code-block:: python + + from cassandra.cluster import Cluster + from cassandra.datastax.graph.fluent import DseGraph + + cluster = Cluster() + ep = DseGraph.create_execution_profile('graph_name', graph_protocol=GraphProtocol.GRAPHSON_3_0) + cluster.add_execution_profile('graph_traversal', ep) + session = cluster.connect() + + g = DseGraph.traversal_source() + query = DseGraph.query_from_traversal(g.V()) + session.execute_graph(query, execution_profile='graph_traversal') + +You can also create multiple GraphTraversalSources and use them with +the same execution profile (for different graphs): + +.. code-block:: python + + g_movies = DseGraph.traversal_source(session, graph_name='movies', ep) + g_series = DseGraph.traversal_source(session, graph_name='series', ep) + + print(g_movies.V().toList()) # Traverse the movies Graph + print(g_series.V().toList()) # Traverse the series Graph + +Batch Queries +~~~~~~~~~~~~~ + +DSE Graph supports batch queries using a :class:`TraversalBatch <.datastax.graph.fluent.query.TraversalBatch>` object +instantiated with :meth:`DseGraph.batch <.datastax.graph.fluent.DseGraph.batch>`. A :class:`TraversalBatch <.datastax.graph.fluent.query.TraversalBatch>` allows +you to execute multiple graph traversals in a single atomic transaction. 
A +traversal batch is executed with :meth:`.Session.execute_graph` or using +:meth:`TraversalBatch.execute <.datastax.graph.fluent.query.TraversalBatch.execute>` if bounded to a DSE session. + +Either way you choose to execute the traversal batch, you need to configure +the execution profile accordingly. Here is a example:: + + from cassandra.cluster import Cluster + from cassandra.datastax.graph.fluent import DseGraph + + ep = DseGraph.create_execution_profile( + 'graph_name', + graph_protocol=GraphProtocol.GRAPHSON_3_0) + cluster = Cluster(execution_profiles={'graphson3': ep}) + session = cluster.connect() + + g = DseGraph.traversal_source() + +To execute the batch using :meth:`.Session.execute_graph`, you need to convert +the batch to a GraphStatement:: + + batch = DseGraph.batch() + + batch.add( + g.addV('genre').property('genreId', 1).property('name', 'Action')) + batch.add( + g.addV('genre').property('genreId', 2).property('name', 'Drama')) # Don't use `.next()` with a batch + + graph_statement = batch.as_graph_statement(graph_protocol=GraphProtocol.GRAPHSON_3_0) + graph_statement.is_idempotent = True # configure any Statement parameters if needed... + session.execute_graph(graph_statement, execution_profile='graphson3') + +To execute the batch using :meth:`TraversalBatch.execute <.datastax.graph.fluent.query.TraversalBatch.execute>`, you need to bound the batch to a DSE session:: + + batch = DseGraph.batch(session, 'graphson3') # bound the session and execution profile + + batch.add( + g.addV('genre').property('genreId', 1).property('name', 'Action')) + batch.add( + g.addV('genre').property('genreId', 2).property('name', 'Drama')) # Don't use `.next()` with a batch + + batch.execute() + +DSL (Domain Specific Languages) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +DSL are very useful to write better domain-specific APIs and avoiding +code duplication. Let's say we have a graph of `People` and we produce +a lot of statistics based on age. All graph traversal queries of our +application would look like:: + + g.V().hasLabel("people").has("age", P.gt(21))... + + +which is not really verbose and quite annoying to repeat in a code base. Let's create a DSL:: + + from gremlin_python.process.graph_traversal import GraphTraversal, GraphTraversalSource + + class MyAppTraversal(GraphTraversal): + + def younger_than(self, age): + return self.has("age", P.lt(age)) + + def older_than(self, age): + return self.has("age", P.gt(age)) + + + class MyAppTraversalSource(GraphTraversalSource): + + def __init__(self, *args, **kwargs): + super(MyAppTraversalSource, self).__init__(*args, **kwargs) + self.graph_traversal = MyAppTraversal + + def people(self): + return self.get_graph_traversal().V().hasLabel("people") + +Now, we can use our DSL that is a lot cleaner:: + + from cassandra.datastax.graph.fluent import DseGraph + + # ... + g = DseGraph.traversal_source(session=session, traversal_class=MyAppTraversalsource) + + g.people().younger_than(21)... + g.people().older_than(30)... + +To see a more complete example of DSL, see the `Python killrvideo DSL app `_ + +Search +~~~~~~ + +DSE Graph can use search indexes that take advantage of DSE Search functionality for +efficient traversal queries. 
Here are the list of additional search predicates: + +Text tokenization: + +* :meth:`token <.datastax.graph.fluent.predicates.Search.token>` +* :meth:`token_prefix <.datastax.graph.fluent.predicates.Search.token_prefix>` +* :meth:`token_regex <.datastax.graph.fluent.predicates.Search.token_regex>` +* :meth:`token_fuzzy <.datastax.graph.fluent.predicates.Search.token_fuzzy>` + +Text match: + +* :meth:`prefix <.datastax.graph.fluent.predicates.Search.prefix>` +* :meth:`regex <.datastax.graph.fluent.predicates.Search.regex>` +* :meth:`fuzzy <.datastax.graph.fluent.predicates.Search.fuzzy>` +* :meth:`phrase <.datastax.graph.fluent.predicates.Search.phrase>` + +Geo: + +* :meth:`inside <.datastax.graph.fluent.predicates.Geo.inside>` + +Create search indexes +--------------------- + +For text tokenization: + +.. code-block:: python + + + s.execute_graph("schema.vertexLabel('my_vertex_label').index('search').search().by('text_field').asText().add()") + +For text match: + +.. code-block:: python + + + s.execute_graph("schema.vertexLabel('my_vertex_label').index('search').search().by('text_field').asString().add()") + + +For geospatial: + +You can create a geospatial index on Point and LineString fields. + +.. code-block:: python + + + s.execute_graph("schema.vertexLabel('my_vertex_label').index('search').search().by('point_field').add()") + + +Using search indexes +-------------------- + +Token: + +.. code-block:: python + + from cassandra.datastax.graph.fluent.predicates import Search + # ... + + g = DseGraph.traversal_source() + query = DseGraph.query_from_traversal( + g.V().has('my_vertex_label','text_field', Search.token_regex('Hello.+World')).values('text_field')) + session.execute_graph(query) + +Text: + +.. code-block:: python + + from cassandra.datastax.graph.fluent.predicates import Search + # ... + + g = DseGraph.traversal_source() + query = DseGraph.query_from_traversal( + g.V().has('my_vertex_label','text_field', Search.prefix('Hello')).values('text_field')) + session.execute_graph(query) + +Geospatial: + +.. code-block:: python + + from cassandra.datastax.graph.fluent.predicates import Geo + from cassandra.util import Distance + # ... + + g = DseGraph.traversal_source() + query = DseGraph.query_from_traversal( + g.V().has('my_vertex_label','point_field', Geo.inside(Distance(46, 71, 100)).values('point_field')) + session.execute_graph(query) + + +For more details, please refer to the official `DSE Search Indexes Documentation `_ diff --git a/docs/index.rst b/docs/index.rst index 903a750666..2fcaf43884 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,50 +1,79 @@ -Python Cassandra Driver -======================= -A Python client driver for `Apache Cassandra `_. +DataStax Python Driver for Apache Cassandra® +============================================ +A Python client driver for `Apache Cassandra® `_. This driver works exclusively with the Cassandra Query Language v3 (CQL3) -and Cassandra's native protocol. Cassandra 1.2+ is supported. +and Cassandra's native protocol. Cassandra 2.1+ is supported, including DSE 4.7+. -The driver supports Python 2.6, 2.7, 3.3, and 3.4. +The driver supports Python 3.9 through 3.13. This driver is open source under the `Apache v2 License `_. The source code for this driver can be found on `GitHub `_. +**Note:** DataStax products do not support big-endian systems. + Contents -------- :doc:`installation` How to install the driver. :doc:`getting_started` - A guide through the first steps of connecting to Cassandra and executing queries. 
+ A guide through the first steps of connecting to Cassandra and executing queries -:doc:`object_mapper` - Introduction to the integrated object mapper, cqlengine +:doc:`execution_profiles` + An introduction to a more flexible way of configuring request execution -:doc:`api/index` - The API documentation. +:doc:`lwt` + Working with results of conditional requests -:doc:`upgrading` - A guide to upgrading versions of the driver. +:doc:`object_mapper` + Introduction to the integrated object mapper, cqlengine :doc:`performance` Tips for getting good performance. :doc:`query_paging` - Notes on paging large query results. + Notes on paging large query results -:doc:`lwt` - Working with results of conditional requests +:doc:`security` + An overview of the security features of the driver + +:doc:`upgrading` + A guide to upgrading versions of the driver :doc:`user_defined_types` - Working with Cassandra 2.1's user-defined types. + Working with Cassandra 2.1's user-defined types -:doc:`security` - An overview of the security features of the driver. +:doc:`dates_and_times` + Some discussion on the driver's approach to working with timestamp, date, time types + +:doc:`cloud` + A guide to connecting to Datastax Astra + +:doc:`column_encryption` + Transparent client-side per-column encryption and decryption + +:doc:`geo_types` + Working with DSE geometry types + +:doc:`graph` + Graph queries with the Core engine + +:doc:`classic_graph` + Graph queries with the Classic engine + +:doc:`graph_fluent` + DataStax Graph Fluent API + +:doc:`CHANGELOG` + Log of changes to the driver, organized by version. :doc:`faq` A collection of Frequently Asked Questions +:doc:`api/index` + The API documentation. + .. toctree:: :hidden: @@ -52,12 +81,19 @@ Contents installation getting_started upgrading + execution_profiles performance query_paging lwt security user_defined_types object_mapper + geo_types + graph + classic_graph + graph_fluent + dates_and_times + cloud faq Getting Help @@ -66,8 +102,7 @@ Visit the :doc:`FAQ section ` in this documentation. Please send questions to the `mailing list `_. -Alternatively, you can use IRC. Connect to the #datastax-drivers channel on irc.freenode.net. -If you don't have an IRC client, you can use `freenode's web-based client `_. +Alternatively, you can use the `DataStax Community `_. Reporting Issues ---------------- @@ -75,10 +110,3 @@ Please report any bugs and make any feature requests on the `JIRA `_ issue tracker. If you would like to contribute, please feel free to open a pull request. - -Indices and Tables -================== - -* :ref:`genindex` -* :ref:`modindex` -* :ref:`search` diff --git a/docs/installation.rst b/docs/installation.rst index 9c5e52ff61..be31551e79 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -3,14 +3,14 @@ Installation Supported Platforms ------------------- -Python 2.6, 2.7, 3.3, and 3.4 are supported. Both CPython (the standard Python +Python 3.9 through 3.13 are supported. Both CPython (the standard Python implementation) and `PyPy `_ are supported and tested. Linux, OSX, and Windows are supported. Installation through pip ------------------------ -`pip `_ is the suggested tool for installing +`pip `_ is the suggested tool for installing packages. It will handle installing all Python dependencies for the driver at the same time as the driver itself. 
To install the driver*:: @@ -20,13 +20,104 @@ You can use ``pip install --pre cassandra-driver`` if you need to install a beta ***Note**: if intending to use optional extensions, install the `dependencies <#optional-non-python-dependencies>`_ first. The driver may need to be reinstalled if dependencies are added after the initial installation. +Verifying your Installation +--------------------------- +To check if the installation was successful, you can run:: + + python -c 'import cassandra; print(cassandra.__version__)' + +It should print something like "3.29.3". + +.. _installation-datastax-graph: + +(*Optional*) DataStax Graph +--------------------------- +The driver provides an optional fluent graph API that depends on Apache TinkerPop (gremlinpython). It is +not installed by default. To be able to build Gremlin traversals, you need to install +the `graph` extra:: + + pip install cassandra-driver[graph] + +See :doc:`graph_fluent` for more details about this API. + +(*Optional*) Compression Support +-------------------------------- +Compression can optionally be used for communication between the driver and +Cassandra. There are currently two supported compression algorithms: +snappy (in Cassandra 1.2+) and LZ4 (only in Cassandra 2.0+). If either is +available for the driver and Cassandra also supports it, it will +be used automatically. + +For lz4 support:: + + pip install lz4 + +For snappy support:: + + pip install python-snappy + +(If using a Debian Linux derivative such as Ubuntu, it may be easier to +just run ``apt-get install python-snappy``.) + +(*Optional*) Metrics Support +---------------------------- +The driver has built-in support for capturing :attr:`.Cluster.metrics` about +the queries you run. However, the ``scales`` library is required to +support this:: + + pip install scales + +*Optional:* Column-Level Encryption (CLE) Support +-------------------------------------------------- +The driver has built-in support for client-side encryption and +decryption of data. For more, see :doc:`column_encryption`. + +CLE depends on the Python `cryptography `_ module. +When installing Python driver 3.27.0. the `cryptography` module is +also downloaded and installed. +If you are using Python driver 3.28.0 or later and want to use CLE, you must +install the `cryptography `_ module. + +You can install this module along with the driver by specifying the `cle` extra:: + + pip install cassandra-driver[cle] + +Alternatively, you can also install the module directly via `pip`:: + + pip install cryptography + +Any version of cryptography >= 35.0 will work for the CLE feature. You can find additional +details at `PYTHON-1351 `_ + +Speeding Up Installation +^^^^^^^^^^^^^^^^^^^^^^^^ + +By default, installing the driver through ``pip`` uses a pre-compiled, platform-specific wheel when available. +If using a source distribution rather than a wheel, Cython is used to compile certain parts of the driver. +This makes those hot paths faster at runtime, but the Cython compilation +process can take a long time -- as long as 10 minutes in some environments. + +In environments where performance is less important, it may be worth it to +:ref:`disable Cython as documented below `. +You can also use ``CASS_DRIVER_BUILD_CONCURRENCY`` to increase the number of +threads used to build the driver and any C extensions: + +.. 
code-block:: bash + + $ # installing from source + $ CASS_DRIVER_BUILD_CONCURRENCY=8 python setup.py install + $ # installing from pip + $ CASS_DRIVER_BUILD_CONCURRENCY=8 pip install cassandra-driver + OSX Installation Error ^^^^^^^^^^^^^^^^^^^^^^ If you're installing on OSX and have XCode 5.1 installed, you may see an error like this:: clang: error: unknown argument: '-mno-fused-madd' [-Wunused-command-line-argument-hard-error-in-future] -To fix this, re-run the installation with an extra compilation flag:: +To fix this, re-run the installation with an extra compilation flag: + +.. code-block:: bash ARCHFLAGS=-Wno-error=unused-command-line-argument-hard-error-in-future pip install cassandra-driver @@ -58,41 +149,6 @@ Once the dependencies are installed, simply run:: python setup.py install -Verifying your Installation ---------------------------- -To check if the installation was successful, you can run:: - - python -c 'import cassandra; print cassandra.__version__' - -It should print something like "2.7.0". - -(*Optional*) Compression Support --------------------------------- -Compression can optionally be used for communication between the driver and -Cassandra. There are currently two supported compression algorithms: -snappy (in Cassandra 1.2+) and LZ4 (only in Cassandra 2.0+). If either is -available for the driver and Cassandra also supports it, it will -be used automatically. - -For lz4 support:: - - pip install lz4 - -For snappy support:: - - pip install python-snappy - -(If using a Debian Linux derivative such as Ubuntu, it may be easier to -just run ``apt-get install python-snappy``.) - -(*Optional*) Metrics Support ----------------------------- -The driver has built-in support for capturing :attr:`.Cluster.metrics` about -the queries you run. However, the ``scales`` library is required to -support this:: - - pip install scales - (*Optional*) Non-python Dependencies ------------------------------------ @@ -123,6 +179,8 @@ On OS X, homebrew installations of Python should provide the necessary headers. See :ref:`windows_build` for notes on configuring the build environment on Windows. +.. _cython-extensions: + Cython-based Extensions ~~~~~~~~~~~~~~~~~~~~~~~ By default, this package uses `Cython `_ to optimize core modules and build custom extensions. @@ -153,16 +211,19 @@ If your sudo configuration does not allow SETENV, you must push the option flag applies these options to all dependencies (which break on the custom flag). Therefore, you must first install dependencies, then use install-option:: - sudo pip install six futures + sudo pip install futures sudo pip install --install-option="--no-cython" +Supported Event Loops +^^^^^^^^^^^^^^^^^^^^^ +For Python versions before 3.12 the driver uses the ``asyncore`` module for its default +event loop. Other event loops such as ``libev``, ``gevent`` and ``eventlet`` are also +available via Python modules or C extensions. Python 3.12 has removed ``asyncore`` entirely +so for this platform one of these other event loops must be used. + libev support ^^^^^^^^^^^^^ -The driver currently uses Python's ``asyncore`` module for its default -event loop. For better performance, ``libev`` is also supported through -a C extension. - If you're on Linux, you should be able to install libev through a package manager. For example, on Debian/Ubuntu:: @@ -177,8 +238,10 @@ through `Homebrew `_. 
For example, on Mac OS X:: $ brew install libev -The libev extension is not built for Windows (the build process is complex, and the Windows implementation uses -select anyway). +The libev extension can now be built for Windows as of Python driver version 3.29.2. You can +install libev using any Windows package manager. For example, to install using `vcpkg `_:: + + $ vcpkg install libev If successful, you should be able to build and install the extension (just using ``setup.py build`` or ``setup.py install``) and then use diff --git a/docs/object_mapper.rst b/docs/object_mapper.rst index 4e38994064..21d2954f4b 100644 --- a/docs/object_mapper.rst +++ b/docs/object_mapper.rst @@ -19,6 +19,9 @@ Contents :doc:`cqlengine/batches` Working with batch mutations +:doc:`cqlengine/connections` + Working with multiple sessions + :ref:`API Documentation ` Index of API documentation @@ -34,6 +37,7 @@ Contents cqlengine/models cqlengine/queryset cqlengine/batches + cqlengine/connections cqlengine/third_party cqlengine/faq @@ -42,60 +46,60 @@ Contents Getting Started --------------- - .. code-block:: python - - import uuid - from cassandra.cqlengine import columns - from cassandra.cqlengine import connection - from datetime import datetime - from cassandra.cqlengine.management import sync_table - from cassandra.cqlengine.models import Model - - #first, define a model - class ExampleModel(Model): - example_id = columns.UUID(primary_key=True, default=uuid.uuid4) - example_type = columns.Integer(index=True) - created_at = columns.DateTime() - description = columns.Text(required=False) - - #next, setup the connection to your cassandra server(s)... - # see http://datastax.github.io/python-driver/api/cassandra/cluster.html for options - # the list of hosts will be passed to create a Cluster() instance - connection.setup(['127.0.0.1'], "cqlengine", protocol_version=3) - - #...and create your CQL table - >>> sync_table(ExampleModel) - - #now we can create some rows: - >>> em1 = ExampleModel.create(example_type=0, description="example1", created_at=datetime.now()) - >>> em2 = ExampleModel.create(example_type=0, description="example2", created_at=datetime.now()) - >>> em3 = ExampleModel.create(example_type=0, description="example3", created_at=datetime.now()) - >>> em4 = ExampleModel.create(example_type=0, description="example4", created_at=datetime.now()) - >>> em5 = ExampleModel.create(example_type=1, description="example5", created_at=datetime.now()) - >>> em6 = ExampleModel.create(example_type=1, description="example6", created_at=datetime.now()) - >>> em7 = ExampleModel.create(example_type=1, description="example7", created_at=datetime.now()) - >>> em8 = ExampleModel.create(example_type=1, description="example8", created_at=datetime.now()) - - #and now we can run some queries against our table - >>> ExampleModel.objects.count() - 8 - >>> q = ExampleModel.objects(example_type=1) - >>> q.count() - 4 - >>> for instance in q: - >>> print instance.description - example5 - example6 - example7 - example8 - - #here we are applying additional filtering to an existing query - #query objects are immutable, so calling filter returns a new - #query object - >>> q2 = q.filter(example_id=em5.example_id) - - >>> q2.count() - 1 - >>> for instance in q2: - >>> print instance.description - example5 +.. 
code-block:: python + + import uuid + from cassandra.cqlengine import columns + from cassandra.cqlengine import connection + from datetime import datetime + from cassandra.cqlengine.management import sync_table + from cassandra.cqlengine.models import Model + + #first, define a model + class ExampleModel(Model): + example_id = columns.UUID(primary_key=True, default=uuid.uuid4) + example_type = columns.Integer(index=True) + created_at = columns.DateTime() + description = columns.Text(required=False) + + #next, setup the connection to your cassandra server(s)... + # see https://docs.datastax.com/en/developer/python-driver/latest/api/cassandra/cluster.html for options + # the list of hosts will be passed to create a Cluster() instance + connection.setup(['127.0.0.1'], "cqlengine", protocol_version=3) + + #...and create your CQL table + >>> sync_table(ExampleModel) + + #now we can create some rows: + >>> em1 = ExampleModel.create(example_type=0, description="example1", created_at=datetime.now()) + >>> em2 = ExampleModel.create(example_type=0, description="example2", created_at=datetime.now()) + >>> em3 = ExampleModel.create(example_type=0, description="example3", created_at=datetime.now()) + >>> em4 = ExampleModel.create(example_type=0, description="example4", created_at=datetime.now()) + >>> em5 = ExampleModel.create(example_type=1, description="example5", created_at=datetime.now()) + >>> em6 = ExampleModel.create(example_type=1, description="example6", created_at=datetime.now()) + >>> em7 = ExampleModel.create(example_type=1, description="example7", created_at=datetime.now()) + >>> em8 = ExampleModel.create(example_type=1, description="example8", created_at=datetime.now()) + + #and now we can run some queries against our table + >>> ExampleModel.objects.count() + 8 + >>> q = ExampleModel.objects(example_type=1) + >>> q.count() + 4 + >>> for instance in q: + >>> print(instance.description) + example5 + example6 + example7 + example8 + + #here we are applying additional filtering to an existing query + #query objects are immutable, so calling filter returns a new + #query object + >>> q2 = q.filter(example_id=em5.example_id) + + >>> q2.count() + 1 + >>> for instance in q2: + >>> print(instance.description) + example5 diff --git a/docs/query_paging.rst b/docs/query_paging.rst index 52366116e8..23ee2c1129 100644 --- a/docs/query_paging.rst +++ b/docs/query_paging.rst @@ -3,7 +3,7 @@ Paging Large Queries ==================== Cassandra 2.0+ offers support for automatic query paging. Starting with -version 2.0 of the driver, if :attr:`~.Cluster.protocol_version` is greater than +version 2.0 of the driver, if :attr:`~.Cluster.protocol_version` is greater than :const:`2` (it is by default), queries returning large result sets will be automatically paged. @@ -74,3 +74,22 @@ pages. For example:: handler.finished_event.wait() if handler.error: raise handler.error + +Resume Paged Results +-------------------- + +You can resume the pagination when executing a new query by using the :attr:`.ResultSet.paging_state`. This can be useful if you want to provide some stateless pagination capabilities to your application (ie. via http). For example:: + + from cassandra.query import SimpleStatement + query = "SELECT * FROM users" + statement = SimpleStatement(query, fetch_size=10) + results = session.execute(statement) + + # save the paging_state somewhere and return current results + web_session['paging_state'] = results.paging_state + + + # resume the pagination sometime later... 
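+    # note: the saved paging_state should be reused with the same query/statement it came from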
+ statement = SimpleStatement(query, fetch_size=10) + ps = web_session['paging_state'] + results = session.execute(statement, paging_state=ps) diff --git a/docs/security.rst b/docs/security.rst index 9f7af68b4d..6dd2624c24 100644 --- a/docs/security.rst +++ b/docs/security.rst @@ -27,9 +27,6 @@ For example, suppose Cassandra is setup with its default cluster = Cluster(auth_provider=auth_provider, protocol_version=2) -When working with version 2 or higher of the driver, the protocol -version is set to 2 by default, but we've included it in the example -to be explicit. Custom Authenticators ^^^^^^^^^^^^^^^^^^^^^ @@ -59,14 +56,254 @@ a dict of credentials with a ``username`` and ``password`` key: SSL --- +SSL should be used when client encryption is enabled in Cassandra. + +To give you as much control as possible over your SSL configuration, our SSL +API takes a user-created `SSLContext` instance from the Python standard library. +These docs will include some examples for how to achieve common configurations, +but the `ssl.SSLContext `_ documentation +gives a more complete description of what is possible. + +To enable SSL with version 3.17.0 and higher, you will need to set :attr:`.Cluster.ssl_context` to a +``ssl.SSLContext`` instance to enable SSL. Optionally, you can also set :attr:`.Cluster.ssl_options` +to a dict of options. These will be passed as kwargs to ``ssl.SSLContext.wrap_socket()`` +when new sockets are created. + +If you create your SSLContext using `ssl.create_default_context `_, +be aware that SSLContext.check_hostname is set to True by default, so the hostname validation will be done +by Python and not the driver. For this reason, we need to set the server_hostname at best effort, which is the +resolved ip address. If this validation needs to be done against the FQDN, consider enabling it using the ssl_options +as described in the following examples or implement your own :class:`~.connection.EndPoint` and +:class:`~.connection.EndPointFactory`. + + +The following examples assume you have generated your Cassandra certificate and +keystore files with these intructions: + +* `Setup SSL Cert `_ + +It might be also useful to learn about the different levels of identity verification to understand the examples: + +* `Using SSL in DSE drivers `_ + +SSL with Twisted or Eventlet +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Twisted and Eventlet both use an alternative SSL implementation called pyOpenSSL, so if your `Cluster`'s connection class is +:class:`~cassandra.io.twistedreactor.TwistedConnection` or :class:`~cassandra.io.eventletreactor.EventletConnection`, you must pass a +`pyOpenSSL context `_ instead. +An example is provided in these docs, and more details can be found in the +`documentation `_. +pyOpenSSL is not installed by the driver and must be installed separately. + +SSL Configuration Examples +^^^^^^^^^^^^^^^^^^^^^^^^^^ +Here, we'll describe the server and driver configuration necessary to set up SSL to meet various goals, such as the client verifying the server and the server verifying the client. We'll also include Python code demonstrating how to use servers and drivers configured in these ways. + +.. _ssl-no-identify-verification: + +No identity verification +++++++++++++++++++++++++ + +No identity verification at all. Note that this is not recommended for for production deployments. 
+ +The Cassandra configuration:: + + client_encryption_options: + enabled: true + keystore: /path/to/127.0.0.1.keystore + keystore_password: myStorePass + require_client_auth: false + +The driver configuration: + +.. code-block:: python + + from cassandra.cluster import Cluster, Session + from ssl import SSLContext, PROTOCOL_TLS + + ssl_context = SSLContext(PROTOCOL_TLS) + + cluster = Cluster(['127.0.0.1'], ssl_context=ssl_context) + session = cluster.connect() + +.. _ssl-client-verifies-server: + +Client verifies server +++++++++++++++++++++++ + +Ensure the python driver verifies the identity of the server. + +The Cassandra configuration:: + + client_encryption_options: + enabled: true + keystore: /path/to/127.0.0.1.keystore + keystore_password: myStorePass + require_client_auth: false + +For the driver configuration, it's very important to set `ssl_context.verify_mode` +to `CERT_REQUIRED`. Otherwise, the loaded verify certificate will have no effect: + +.. code-block:: python + + from cassandra.cluster import Cluster, Session + from ssl import SSLContext, PROTOCOL_TLS, CERT_REQUIRED + + ssl_context = SSLContext(PROTOCOL_TLS) + ssl_context.load_verify_locations('/path/to/rootca.crt') + ssl_context.verify_mode = CERT_REQUIRED + + cluster = Cluster(['127.0.0.1'], ssl_context=ssl_context) + session = cluster.connect() + +Additionally, you can also force the driver to verify the `hostname` of the server by passing additional options to `ssl_context.wrap_socket` via the `ssl_options` kwarg: + +.. code-block:: python + + from cassandra.cluster import Cluster, Session + from ssl import SSLContext, PROTOCOL_TLS, CERT_REQUIRED + + ssl_context = SSLContext(PROTOCOL_TLS) + ssl_context.load_verify_locations('/path/to/rootca.crt') + ssl_context.verify_mode = CERT_REQUIRED + ssl_context.check_hostname = True + ssl_options = {'server_hostname': '127.0.0.1'} + + cluster = Cluster(['127.0.0.1'], ssl_context=ssl_context, ssl_options=ssl_options) + session = cluster.connect() + +.. _ssl-server-verifies-client: + +Server verifies client +++++++++++++++++++++++ + +If Cassandra is configured to verify clients (``require_client_auth``), you need to generate +SSL key and certificate files. + +The cassandra configuration:: + + client_encryption_options: + enabled: true + keystore: /path/to/127.0.0.1.keystore + keystore_password: myStorePass + require_client_auth: true + truststore: /path/to/dse-truststore.jks + truststore_password: myStorePass + +The Python ``ssl`` APIs require the certificate in PEM format. First, create a certificate +conf file: + +.. code-block:: bash + + cat > gen_client_cert.conf <`_ +for more details about ``SSLContext`` configuration. + +**Server verifies client and client verifies server using Twisted and pyOpenSSL** + +.. 
code-block:: python + + from OpenSSL import SSL, crypto + from cassandra.cluster import Cluster + from cassandra.io.twistedreactor import TwistedConnection + + ssl_context = SSL.Context(SSL.TLSv1_2_METHOD) + ssl_context.set_verify(SSL.VERIFY_PEER, callback=lambda _1, _2, _3, _4, ok: ok) + ssl_context.use_certificate_file('/path/to/client.crt_signed') + ssl_context.use_privatekey_file('/path/to/client.key') + ssl_context.load_verify_locations('/path/to/rootca.crt') + + cluster = Cluster( + contact_points=['127.0.0.1'], + connection_class=TwistedConnection, + ssl_context=ssl_context, + ssl_options={'check_hostname': True} + ) + session = cluster.connect() + + +Connecting using Eventlet would look similar except instead of importing and using ``TwistedConnection``, you would +import and use ``EventletConnection``, including the appropriate monkey-patching. + +Versions 3.16.0 and lower +^^^^^^^^^^^^^^^^^^^^^^^^^ + To enable SSL you will need to set :attr:`.Cluster.ssl_options` to a dict of options. These will be passed as kwargs to ``ssl.wrap_socket()`` -when new sockets are created. This should be used when client encryption -is enabled in Cassandra. +when new sockets are created. Note that this use of ssl_options will be +deprecated in the next major release. By default, a ``ca_certs`` value should be supplied (the value should be a string pointing to the location of the CA certs file), and you probably -want to specify ``ssl_version`` as ``ssl.PROTOCOL_TLSv1`` to match +want to specify ``ssl_version`` as ``ssl.PROTOCOL_TLS`` to match Cassandra's default protocol. For example: @@ -74,11 +311,111 @@ For example: .. code-block:: python from cassandra.cluster import Cluster - from ssl import PROTOCOL_TLSv1 + from ssl import PROTOCOL_TLS, CERT_REQUIRED - ssl_opts = {'ca_certs': '/path/to/my/ca.certs', - 'ssl_version': PROTOCOL_TLSv1} + ssl_opts = { + 'ca_certs': '/path/to/my/ca.certs', + 'ssl_version': PROTOCOL_TLS, + 'cert_reqs': CERT_REQUIRED # Certificates are required and validated + } cluster = Cluster(ssl_options=ssl_opts) -For further reading, Andrew Mussey has published a thorough guide on +This is only an example to show how to pass the ssl parameters. Consider reading +the `python ssl documentation `_ for +your configuration. For further reading, Andrew Mussey has published a thorough guide on `Using SSL with the DataStax Python driver `_. + +SSL with Twisted +++++++++++++++++ + +In case the twisted event loop is used pyOpenSSL must be installed or an exception will be risen. Also +to set the ``ssl_version`` and ``cert_reqs`` in ``ssl_opts`` the appropriate constants from pyOpenSSL are expected. + +DSE Authentication +------------------ +When authenticating against DSE, the Cassandra driver provides two auth providers that work both with legacy kerberos and Cassandra authenticators, +as well as the new DSE Unified Authentication. This allows client to configure this auth provider independently, +and in advance of any server upgrade. These auth providers are configured in the same way as any previous implementation:: + + from cassandra.auth import DSEGSSAPIAuthProvider + auth_provider = DSEGSSAPIAuthProvider(service='dse', qops=["auth"]) + cluster = Cluster(auth_provider=auth_provider) + session = cluster.connect() + +Implementations are :attr:`.DSEPlainTextAuthProvider`, :class:`.DSEGSSAPIAuthProvider` and :class:`.SaslAuthProvider`. 
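+
+For example, plain-text DSE authentication can be configured with :attr:`.DSEPlainTextAuthProvider`
+(a minimal sketch; replace the placeholder credentials with your own)::
+
+    from cassandra.cluster import Cluster
+    from cassandra.auth import DSEPlainTextAuthProvider
+
+    auth_provider = DSEPlainTextAuthProvider('username', 'password')
+    cluster = Cluster(auth_provider=auth_provider)
+    session = cluster.connect()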
+ +DSE Unified Authentication +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +With DSE (>=5.1), unified Authentication allows you to: + +* Proxy Login: Authenticate using a fixed set of authentication credentials but allow authorization of resources based another user id. +* Proxy Execute: Authenticate using a fixed set of authentication credentials but execute requests based on another user id. + +Proxy Login ++++++++++++ + +Proxy login allows you to authenticate with a user but act as another one. You need to ensure the authenticated user has the permission to use the authorization of resources of the other user. ie. this example will allow the `server` user to authenticate as usual but use the authorization of `user1`: + +.. code-block:: text + + GRANT PROXY.LOGIN on role user1 to server + +then you can do the proxy authentication.... + +.. code-block:: python + + from cassandra.cluster import Cluster + from cassandra.auth import SaslAuthProvider + + sasl_kwargs = { + "service": 'dse', + "mechanism":"PLAIN", + "username": 'server', + 'password': 'server', + 'authorization_id': 'user1' + } + + auth_provider = SaslAuthProvider(**sasl_kwargs) + c = Cluster(auth_provider=auth_provider) + s = c.connect() + s.execute(...) # all requests will be executed as 'user1' + +If you are using kerberos, you can use directly :class:`.DSEGSSAPIAuthProvider` and pass the authorization_id, like this: + +.. code-block:: python + + from cassandra.cluster import Cluster + from cassandra.auth import DSEGSSAPIAuthProvider + + # Ensure the kerberos ticket of the server user is set with the kinit utility. + auth_provider = DSEGSSAPIAuthProvider(service='dse', qops=["auth"], principal="server@DATASTAX.COM", + authorization_id='user1@DATASTAX.COM') + c = Cluster(auth_provider=auth_provider) + s = c.connect() + s.execute(...) # all requests will be executed as 'user1' + + +Proxy Execute ++++++++++++++ + +Proxy execute allows you to execute requests as another user than the authenticated one. You need to ensure the authenticated user has the permission to use the authorization of resources of the specified user. ie. this example will allow the `server` user to execute requests as `user1`: + +.. code-block:: text + + GRANT PROXY.EXECUTE on role user1 to server + +then you can do a proxy execute... + +.. code-block:: python + + from cassandra.cluster import Cluster + from cassandra.auth import DSEPlainTextAuthProvider, + + auth_provider = DSEPlainTextAuthProvider('server', 'server') + + c = Cluster(auth_provider=auth_provider) + s = c.connect() + s.execute('select * from k.t;', execute_as='user1') # the request will be executed as 'user1' + +Please see the `official documentation `_ for more details on the feature and configuration process. 
diff --git a/docs/themes/custom/static/custom.css_t b/docs/themes/custom/static/custom.css_t index 00d9b2cf62..c3460e75a5 100644 --- a/docs/themes/custom/static/custom.css_t +++ b/docs/themes/custom/static/custom.css_t @@ -1,4 +1,21 @@ -@import url("sphinxdoc.css"); +@import url("alabaster.css"); + +div.document { + width: 1200px; +} + +div.sphinxsidebar h1.logo a { + font-size: 24px; +} + +code.descname { + color: #4885ed; +} + +th.field-name { + min-width: 100px; + color: #3cba54; +} div.versionmodified { font-weight: bold diff --git a/docs/themes/custom/theme.conf b/docs/themes/custom/theme.conf index f4b51356f2..b0fbb6961e 100644 --- a/docs/themes/custom/theme.conf +++ b/docs/themes/custom/theme.conf @@ -1,4 +1,11 @@ [theme] -inherit = sphinxdoc +inherit = alabaster stylesheet = custom.css pygments_style = friendly + +[options] +description = Python driver for Cassandra +github_user = datastax +github_repo = python-driver +github_button = true +github_type = star \ No newline at end of file diff --git a/docs/upgrading.rst b/docs/upgrading.rst index 2fa86b8fc7..3fd937d7bc 100644 --- a/docs/upgrading.rst +++ b/docs/upgrading.rst @@ -4,6 +4,85 @@ Upgrading .. toctree:: :maxdepth: 1 +Upgrading from dse-driver +------------------------- + +Since 3.21.0, cassandra-driver fully supports DataStax products. dse-driver and +dse-graph users should now migrate to cassandra-driver to benefit from latest bug fixes +and new features. The upgrade to this new unified driver version is straightforward +with no major API changes. + +Installation +^^^^^^^^^^^^ + +Only the `cassandra-driver` package should be installed. `dse-driver` and `dse-graph` +are not required anymore:: + + pip install cassandra-driver + +If you need the Graph *Fluent* API (features provided by dse-graph):: + + pip install cassandra-driver[graph] + +See :doc:`installation` for more details. + +Import from the cassandra module +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +There is no `dse` module, so you should import from the `cassandra` module. You +need to change only the first module of your import statements, not the submodules. + +.. code-block:: python + + from dse.cluster import Cluster, EXEC_PROFILE_GRAPH_DEFAULT + from dse.auth import PlainTextAuthProvider + from dse.policies import WhiteListRoundRobinPolicy + + # becomes + + from cassandra.cluster import Cluster, EXEC_PROFILE_GRAPH_DEFAULT + from cassandra.auth import PlainTextAuthProvider + from cassandra.policies import WhiteListRoundRobinPolicy + +Also note that the cassandra.hosts module doesn't exist in cassandra-driver. This +module is named cassandra.pool. + +dse-graph +^^^^^^^^^ + +dse-graph features are now built-in in cassandra-driver. The only change you need +to do is your import statements: + +.. code-block:: python + + from dse_graph import .. + from dse_graph.query import .. + + # becomes + + from cassandra.datastax.graph.fluent import .. + from cassandra.datastax.graph.fluent.query import .. + +See :mod:`~.datastax.graph.fluent`. + +Session.execute and Session.execute_async API +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Although it is not common to use this API with positional arguments, it is +important to be aware that the `host` and `execute_as` parameters have had +their positional order swapped. This is only because `execute_as` was added +in dse-driver before `host`. + +See :meth:`.Session.execute`. + +Deprecations +^^^^^^^^^^^^ + +These changes are optional, but recommended: + +* Importing from `cassandra.graph` is deprecated. 
Consider importing from `cassandra.datastax.graph`. +* Use :class:`~.policies.DefaultLoadBalancingPolicy` instead of DSELoadBalancingPolicy. + Upgrading to 3.0 ---------------- Version 3.0 of the DataStax Python driver for Apache Cassandra @@ -46,9 +125,16 @@ materialize a list using the iterator: results = session.execute("SELECT * FROM system.local") row_list = list(results) -For backward compatability, :class:`~.ResultSet` supports indexing. If -the result is paged, all pages will be materialized. A warning will -be logged if a paged query is implicitly materialized. +For backward compatibility, :class:`~.ResultSet` supports indexing. When +accessed at an index, a `~.ResultSet` object will materialize all its pages: + +.. code-block:: python + + results = session.execute("SELECT * FROM system.local") + first_result = results[0] # materializes results, fetching all pages + +This can send requests and load (possibly large) results into memory, so +`~.ResultSet` will log a warning on implicit materialization. Trace information is not attached to executed Statements ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -250,7 +336,7 @@ See :ref:`query-paging` for full details. Protocol-Level Batch Statements ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ With version 1 of the native protocol, batching of statements required -using a `BATCH cql query `_. +using a `BATCH cql query `_. With version 2 of the native protocol, you can now batch statements at the protocol level. This allows you to use many different prepared statements within a single batch. @@ -296,7 +382,3 @@ The following dependencies have officially been made optional: * ``scales`` * ``blist`` - -And one new dependency has been added (to enable Python 3 support): - -* ``six`` diff --git a/docs/user_defined_types.rst b/docs/user_defined_types.rst index fd95b09fc4..32c03e37e8 100644 --- a/docs/user_defined_types.rst +++ b/docs/user_defined_types.rst @@ -9,12 +9,16 @@ new type through ``CREATE TYPE`` statements in CQL:: Version 2.1 of the Python driver adds support for user-defined types. -Registering a Class to Map to a UDT ------------------------------------ +Registering a UDT +----------------- You can tell the Python driver to return columns of a specific UDT as -instances of a class by registering them with your :class:`~.Cluster` +instances of a class or a dict by registering them with your :class:`~.Cluster` instance through :meth:`.Cluster.register_user_type`: + +Map a Class to a UDT +++++++++++++++++++++ + .. code-block:: python cluster = Cluster(protocol_version=3) @@ -39,7 +43,29 @@ instance through :meth:`.Cluster.register_user_type`: # results will include Address instances results = session.execute("SELECT * FROM users") row = results[0] - print row.id, row.location.street, row.location.zipcode + print(row.id, row.location.street, row.location.zipcode) + +Map a dict to a UDT ++++++++++++++++++++ + +.. code-block:: python + + cluster = Cluster(protocol_version=3) + session = cluster.connect() + session.set_keyspace('mykeyspace') + session.execute("CREATE TYPE address (street text, zipcode int)") + session.execute("CREATE TABLE users (id int PRIMARY KEY, location frozen
)") + + cluster.register_user_type('mykeyspace', 'address', dict) + + # insert a row using a prepared statement and a tuple + insert_statement = session.prepare("INSERT INTO mykeyspace.users (id, location) VALUES (?, ?)") + session.execute(insert_statement, [0, ("123 Main St.", 78723)]) + + # results will include dict instances + results = session.execute("SELECT * FROM users") + row = results[0] + print(row.id, row.location['street'], row.location['zipcode']) Using UDTs Without Registering Them ----------------------------------- @@ -79,7 +105,7 @@ for the UDT: results = session.execute("SELECT * FROM users") first_row = results[0] address = first_row.location - print address # prints "Address(street='123 Main St.', zipcode=78723)" + print(address) # prints "Address(street='123 Main St.', zipcode=78723)" street = address.street zipcode = address.street diff --git a/example_core.py b/example_core.py index 3235c79bad..56e8924d1d 100644 --- a/example_core.py +++ b/example_core.py @@ -1,12 +1,14 @@ #!/usr/bin/env python -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/example_mapper.py b/example_mapper.py index 4d5ebca361..8105dbe2b1 100755 --- a/example_mapper.py +++ b/example_mapper.py @@ -1,12 +1,14 @@ #!/usr/bin/env python -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -76,7 +78,10 @@ def main(): try: nick.iff(birth_year=1988).update(birth_year=1989) except LWTException: - print "precondition not met" + print("precondition not met") + + log.info("### setting individual column to NULL by updating it to None") + nick.update(birth_year=None) # showing validation try: @@ -96,15 +101,15 @@ def main(): log.info("### All members") for m in FamilyMembers.all(): - print m, m.birth_year, m.sex + print(m, m.birth_year, m.sex) log.info("### Select by partition key") for m in FamilyMembers.objects(id=simmons.id): - print m, m.birth_year, m.sex + print(m, m.birth_year, m.sex) log.info("### Constrain on clustering key") for m in FamilyMembers.objects(id=simmons.id, surname=simmons.surname): - print m, m.birth_year, m.sex + print(m, m.birth_year, m.sex) log.info("### Constrain on clustering key") kids = FamilyMembers.objects(id=simmons.id, surname=simmons.surname, name__in=['Nick', 'Sophie']) @@ -112,7 +117,7 @@ def main(): log.info("### Delete a record") FamilyMembers(id=hogan_id, surname='Hogan', name='Linda').delete() for m in FamilyMembers.objects(id=hogan_id): - print m, m.birth_year, m.sex + print(m, m.birth_year, m.sex) management.drop_keyspace(KEYSPACE) diff --git a/examples/README.rst b/examples/README.rst new file mode 100644 index 0000000000..889f911132 --- /dev/null +++ b/examples/README.rst @@ -0,0 +1,8 @@ +Driver Examples +=============== +This directory will contain a set of scripts demonstrating driver APIs or integration techniques. It will not be exhaustive, but will contain examples where they are too involved, or +open-ended to include inline in the docstrings. In that case, they should be referenced from the docstrings + +Features +-------- +* `request_init_listener.py `_ A script demonstrating how to register a session request listener and use it to track alternative metrics about requests (size, for example). diff --git a/examples/concurrent_executions/execute_async_with_queue.py b/examples/concurrent_executions/execute_async_with_queue.py new file mode 100644 index 0000000000..44a91a530c --- /dev/null +++ b/examples/concurrent_executions/execute_async_with_queue.py @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Inserts multiple rows in a table asynchronously, limiting the amount +of parallel requests with a Queue. 
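+
+Assumes a Cassandra node is reachable at the driver's default contact point (127.0.0.1).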
+""" + +import time +import uuid +import queue + +from cassandra.cluster import Cluster + + +CONCURRENCY_LEVEL = 32 +TOTAL_QUERIES = 10000 + +cluster = Cluster() +session = cluster.connect() + +session.execute(("CREATE KEYSPACE IF NOT EXISTS examples " + "WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1' }")) +session.execute("USE examples") +session.execute("CREATE TABLE IF NOT EXISTS tbl_sample_kv (id uuid, value text, PRIMARY KEY (id))") +prepared_insert = session.prepare("INSERT INTO tbl_sample_kv (id, value) VALUES (?, ?)") + + +def clear_queue(): + while True: + try: + futures.get_nowait().result() + except queue.Empty: + break + + +start = time.time() +futures = queue.Queue(maxsize=CONCURRENCY_LEVEL) + +# Chunking way, when the max concurrency level is reached, we +# wait the current chunk of requests to finish +for i in range(TOTAL_QUERIES): + future = session.execute_async(prepared_insert, (uuid.uuid4(), str(i))) + try: + futures.put_nowait(future) + except queue.Full: + clear_queue() + futures.put_nowait(future) + +clear_queue() +end = time.time() + +print("Finished executing {} queries with a concurrency level of {} in {:.2f} seconds.". + format(TOTAL_QUERIES, CONCURRENCY_LEVEL, (end-start))) diff --git a/examples/concurrent_executions/execute_with_threads.py b/examples/concurrent_executions/execute_with_threads.py new file mode 100644 index 0000000000..69126de6ec --- /dev/null +++ b/examples/concurrent_executions/execute_with_threads.py @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Inserts multiple rows in a table limiting the amount of parallel requests. + +Note that the driver also provide convenient utility functions to accomplish this. 
+See https://docs.datastax.com/en/developer/python-driver/latest/api/cassandra/concurrent/ +""" + +import time +import uuid +import threading +from cassandra.cluster import Cluster + + +CONCURRENCY_LEVEL = 32 +TOTAL_QUERIES = 10000 +COUNTER = 0 +COUNTER_LOCK = threading.Lock() + +cluster = Cluster() +session = cluster.connect() + +session.execute(("CREATE KEYSPACE IF NOT EXISTS examples " + "WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1' }")) +session.execute("USE examples") +session.execute("CREATE TABLE IF NOT EXISTS tbl_sample_kv (id uuid, value text, PRIMARY KEY (id))") +prepared_insert = session.prepare("INSERT INTO tbl_sample_kv (id, value) VALUES (?, ?)") + + +class SimpleQueryExecutor(threading.Thread): + + def run(self): + global COUNTER + + while True: + with COUNTER_LOCK: + current = COUNTER + COUNTER += 1 + + if current >= TOTAL_QUERIES: + break + + session.execute(prepared_insert, (uuid.uuid4(), str(current))) + + +# Launch in parallel n async operations (n being the concurrency level) +start = time.time() +threads = [] +for i in range(CONCURRENCY_LEVEL): + t = SimpleQueryExecutor() + threads.append(t) + t.start() + +for thread in threads: + thread.join() +end = time.time() + +print("Finished executing {} queries with a concurrency level of {} in {:.2f} seconds.". + format(TOTAL_QUERIES, CONCURRENCY_LEVEL, (end-start))) diff --git a/examples/request_init_listener.py b/examples/request_init_listener.py new file mode 100644 index 0000000000..6cce4953e2 --- /dev/null +++ b/examples/request_init_listener.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script shows an example "request init listener" which can be registered to track certain request metrics +# for a session. In this case we're just accumulating total request and error counts, as well as some statistics +# about the encoded request size. Note that the counts would be available using the internal 'metrics' tracking -- +# this is just demonstrating a way to track a few custom attributes. + +from cassandra.cluster import Cluster +from greplin import scales + +import pprint +pp = pprint.PrettyPrinter(indent=2) + + +class RequestAnalyzer(object): + """ + Class used to track request and error counts for a Session. + + Also computes statistics on encoded request size. + """ + + requests = scales.PmfStat('request size') + errors = scales.IntStat('errors') + + def __init__(self, session): + scales.init(self, '/cassandra') + # each instance will be registered with a session, and receive a callback for each request generated + session.add_request_init_listener(self.on_request) + + def on_request(self, rf): + # This callback is invoked each time a request is created, on the thread creating the request. 
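+        # The single argument is the ResponseFuture created for the request.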
+ # We can use this to count events, or add callbacks + rf.add_callbacks(self.on_success, self.on_error, callback_args=(rf,), errback_args=(rf,)) + + def on_success(self, _, response_future): + # future callback on a successful request; just record the size + self.requests.addValue(response_future.request_encoded_size) + + def on_error(self, _, response_future): + # future callback for failed; record size and increment errors + self.requests.addValue(response_future.request_encoded_size) + self.errors += 1 + + def __str__(self): + # just extracting request count from the size stats (which are recorded on all requests) + request_sizes = dict(self.requests) + count = request_sizes.pop('count') + return "%d requests (%d errors)\nRequest size statistics:\n%s" % (count, self.errors, pp.pformat(request_sizes)) + + +# connect a session +session = Cluster().connect() + +# attach a listener to this session +ra = RequestAnalyzer(session) + +session.execute("SELECT release_version FROM system.local") +session.execute("SELECT release_version FROM system.local") + +print(ra) +# 2 requests (0 errors) +# Request size statistics: +# { '75percentile': 74, +# '95percentile': 74, +# '98percentile': 74, +# '999percentile': 74, +# '99percentile': 74, +# 'max': 74, +# 'mean': 74.0, +# 'median': 74.0, +# 'min': 74, +# 'stddev': 0.0} + +try: + # intentional error to show that count increase + session.execute("syntax err") +except Exception as e: + pass + +print() +print(ra) # note: the counts are updated, but the stats are not because scales only updates every 20s +# 3 requests (1 errors) +# Request size statistics: +# { '75percentile': 74, +# '95percentile': 74, +# '98percentile': 74, +# '999percentile': 74, +# '99percentile': 74, +# 'max': 74, +# 'mean': 74.0, +# 'median': 74.0, +# 'min': 74, +# 'stddev': 0.0} diff --git a/ez_setup.py b/ez_setup.py index 2535472190..76e71057f0 100644 --- a/ez_setup.py +++ b/ez_setup.py @@ -20,6 +20,7 @@ import tarfile import optparse import subprocess +from urllib.request import urlopen from distutils import log @@ -148,10 +149,6 @@ def download_setuptools(version=DEFAULT_VERSION, download_base=DEFAULT_URL, """ # making sure we use the absolute path to_dir = os.path.abspath(to_dir) - try: - from urllib.request import urlopen - except ImportError: - from urllib2 import urlopen tgz_name = "setuptools-%s.tar.gz" % version url = download_base + tgz_name saveto = os.path.join(to_dir, tgz_name) @@ -197,13 +194,7 @@ def _extractall(self, path=".", members=None): self.extract(tarinfo, path) # Reverse sort directories. - if sys.version_info < (2, 4): - def sorter(dir1, dir2): - return cmp(dir1.name, dir2.name) - directories.sort(sorter) - directories.reverse() - else: - directories.sort(key=operator.attrgetter('name'), reverse=True) + directories.sort(key=operator.attrgetter('name'), reverse=True) # Set correct owner, mtime and filemode on directories. for tarinfo in directories: diff --git a/requirements.txt b/requirements.txt index 54bd98a1c2..1d5f0bcfc4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1 @@ -six >=1.6 -futures <=2.2.0 -# Futures is not required for Python 3, but it works up through 2.2.0 (after which it introduced breaking syntax). -# This is left here to make sure install -r works with any runtime. When installing via setup.py, futures is omitted -# for Python 3, in favor of the standard library implementation. 
-# see PYTHON-393 +geomet>=1.1 diff --git a/setup.py b/setup.py index a6fcc7b4b2..5144e90501 100644 --- a/setup.py +++ b/setup.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,21 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import print_function import os import sys import warnings -if __name__ == '__main__' and sys.argv[1] == "gevent_nosetests": - print("Running gevent tests") - from gevent.monkey import patch_all - patch_all() - -if __name__ == '__main__' and sys.argv[1] == "eventlet_nosetests": - print("Running eventlet tests") - from eventlet import monkey_patch - monkey_patch() - import ez_setup ez_setup.use_setuptools() @@ -37,8 +28,6 @@ DistutilsExecError) from distutils.cmd import Command -PY3 = sys.version_info[0] == 3 - try: import subprocess has_subprocess = True @@ -51,19 +40,6 @@ with open("README.rst") as f: long_description = f.read() - -try: - from nose.commands import nosetests -except ImportError: - gevent_nosetests = None - eventlet_nosetests = None -else: - class gevent_nosetests(nosetests): - description = "run nosetests with gevent monkey patching" - - class eventlet_nosetests(nosetests): - description = "run nosetests with eventlet monkey patching" - has_cqlengine = False if __name__ == '__main__' and sys.argv[1] == "install": try: @@ -100,7 +76,7 @@ def run(self): try: os.makedirs(path) - except: + except OSError: pass if has_subprocess: @@ -108,7 +84,7 @@ def run(self): # http://docs.cython.org/src/userguide/special_methods.html#docstrings import glob for f in glob.glob("cassandra/*.so"): - print("Removing '%s' to allow docs to run on pure python modules." %(f,)) + print("Removing '%s' to allow docs to run on pure python modules." 
% (f,)) os.unlink(f) # Build io extension to make import and docstrings work @@ -141,14 +117,30 @@ def __init__(self, ext): self.ext = ext +is_windows = sys.platform.startswith('win32') +is_macos = sys.platform.startswith('darwin') + murmur3_ext = Extension('cassandra.cmurmur3', sources=['cassandra/cmurmur3.c']) + +def eval_env_var_as_array(varname): + val = os.environ.get(varname) + return None if not val else [v.strip() for v in val.split(',')] + + +DEFAULT_LIBEV_INCLUDES = ['/usr/include/libev', '/usr/local/include', '/opt/local/include', '/usr/include'] +DEFAULT_LIBEV_LIBDIRS = ['/usr/local/lib', '/opt/local/lib', '/usr/lib64'] +libev_includes = eval_env_var_as_array('CASS_DRIVER_LIBEV_INCLUDES') or DEFAULT_LIBEV_INCLUDES +libev_libdirs = eval_env_var_as_array('CASS_DRIVER_LIBEV_LIBS') or DEFAULT_LIBEV_LIBDIRS +if is_macos: + libev_includes.extend(['/opt/homebrew/include', os.path.expanduser('~/homebrew/include')]) + libev_libdirs.extend(['/opt/homebrew/lib']) libev_ext = Extension('cassandra.io.libevwrapper', sources=['cassandra/io/libevwrapper.c'], - include_dirs=['/usr/include/libev', '/usr/local/include', '/opt/local/include'], + include_dirs=libev_includes, libraries=['ev'], - library_dirs=['/usr/local/lib', '/opt/local/lib']) + library_dirs=libev_libdirs) platform_unsupported_msg = \ """ @@ -171,8 +163,6 @@ def __init__(self, ext): ================================================================================= """ -is_windows = os.name == 'nt' - is_pypy = "PyPy" in sys.version if is_pypy: sys.stderr.write(pypy_unsupported_msg) @@ -186,7 +176,7 @@ def __init__(self, ext): try_extensions = "--no-extensions" not in sys.argv and is_supported_platform and is_supported_arch and not os.environ.get('CASS_DRIVER_NO_EXTENSIONS') try_murmur3 = try_extensions and "--no-murmur3" not in sys.argv -try_libev = try_extensions and "--no-libev" not in sys.argv and not is_pypy and not is_windows +try_libev = try_extensions and "--no-libev" not in sys.argv and not is_pypy and not os.environ.get('CASS_DRIVER_NO_LIBEV') try_cython = try_extensions and "--no-cython" not in sys.argv and not is_pypy and not os.environ.get('CASS_DRIVER_NO_CYTHON') try_cython &= 'egg_info' not in sys.argv # bypass setup_requires for pip egg_info calls, which will never have --install-option"--no-cython" coming fomr pip @@ -212,7 +202,7 @@ def __init__(self, *args, **kwargs): base.__init__(self, *args, **kwargs) else: Extension.__init__(self, *args, **kwargs) - + class build_extensions(build_ext): @@ -301,6 +291,7 @@ def _setup_extensions(self): self.extensions.append(murmur3_ext) if try_libev: + sys.stderr.write("Appending libev extension %s" % libev_ext) self.extensions.append(libev_ext) if try_cython: @@ -348,6 +339,13 @@ def pre_build_check(): compiler = new_compiler(compiler=be.compiler) customize_compiler(compiler) + try: + # We must be able to initialize the compiler if it has that method + if hasattr(compiler, "initialize"): + compiler.initialize() + except OSError: + return False + executables = [] if compiler.compiler_type in ('unix', 'cygwin'): executables = [compiler.executables[exe][0] for exe in ('compiler_so', 'linker_so')] @@ -373,12 +371,6 @@ def pre_build_check(): def run_setup(extensions): kw = {'cmdclass': {'doc': DocCommand}} - if gevent_nosetests is not None: - kw['cmdclass']['gevent_nosetests'] = gevent_nosetests - - if eventlet_nosetests is not None: - kw['cmdclass']['eventlet_nosetests'] = eventlet_nosetests - kw['cmdclass']['build_ext'] = build_extensions kw['ext_modules'] = [Extension('DUMMY', 
[])] # dummy extension makes sure build_ext is called for install @@ -388,28 +380,45 @@ def run_setup(extensions): # 1.) build_ext eats errors at compile time, letting the install complete while producing useful feedback # 2.) there could be a case where the python environment has cython installed but the system doesn't have build tools if pre_build_check(): - kw['setup_requires'] = ['Cython>=0.20'] + cython_dep = 'Cython>=3.0' + user_specified_cython_version = os.environ.get('CASS_DRIVER_ALLOWED_CYTHON_VERSION') + if user_specified_cython_version is not None: + cython_dep = 'Cython==%s' % (user_specified_cython_version,) + kw['setup_requires'] = [cython_dep] else: sys.stderr.write("Bypassing Cython setup requirement\n") - dependencies = ['six >=1.6'] + dependencies = ['geomet>=1.1'] - if not PY3: - dependencies.append('futures') + _EXTRAS_REQUIRE = { + 'graph': ['gremlinpython==3.4.6'], + 'cle': ['cryptography>=42.0'] + } setup( name='cassandra-driver', version=__version__, - description='Python driver for Cassandra', + description='Apache Cassandra Python Driver', long_description=long_description, + long_description_content_type='text/x-rst', url='http://github.com/datastax/python-driver', - author='Tyler Hobbs', - author_email='tyler@datastax.com', - packages=['cassandra', 'cassandra.io', 'cassandra.cqlengine'], - keywords='cassandra,cql,orm', + project_urls={ + 'Documentation': 'https://docs.datastax.com/en/developer/python-driver/latest/', + 'Source': 'https://github.com/datastax/python-driver/', + 'Issues': 'https://datastax-oss.atlassian.net/browse/PYTHON', + }, + author='DataStax', + packages=[ + 'cassandra', 'cassandra.io', 'cassandra.cqlengine', 'cassandra.graph', + 'cassandra.datastax', 'cassandra.datastax.insights', 'cassandra.datastax.graph', + 'cassandra.datastax.graph.fluent', 'cassandra.datastax.cloud', + "cassandra.column_encryption" + ], + keywords='cassandra,cql,orm,dse,graph', include_package_data=True, install_requires=dependencies, - tests_require=['nose', 'mock<=1.0.1', 'PyYAML', 'pytz', 'sure'], + extras_require=_EXTRAS_REQUIRE, + tests_require=['pytest', 'PyYAML', 'pytz'], classifiers=[ 'Development Status :: 5 - Production/Stable', 'Intended Audience :: Developers', @@ -417,16 +426,18 @@ def run_setup(extensions): 'Natural Language :: English', 'Operating System :: OS Independent', 'Programming Language :: Python', - 'Programming Language :: Python :: 2.6', - 'Programming Language :: Python :: 2.7', - 'Programming Language :: Python :: 3.3', - 'Programming Language :: Python :: 3.4', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', + 'Programming Language :: Python :: 3.13', 'Programming Language :: Python :: Implementation :: CPython', 'Programming Language :: Python :: Implementation :: PyPy', 'Topic :: Software Development :: Libraries :: Python Modules' ], **kw) + run_setup(None) if has_cqlengine: diff --git a/test-datastax-requirements.txt b/test-datastax-requirements.txt new file mode 100644 index 0000000000..d605f6dc51 --- /dev/null +++ b/test-datastax-requirements.txt @@ -0,0 +1,4 @@ +-r test-requirements.txt +kerberos +gremlinpython==3.4.6 +cryptography >= 42.0 diff --git a/test-requirements.txt b/test-requirements.txt index 4c917da6c6..513451b496 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,14 +1,13 @@ -r requirements.txt scales -nose -mock<=1.0.1 -ccm>=2.0 -unittest2 -PyYAML +pytest +ccm>=3.1.5 pytz -sure 
pure-sasl -twisted -gevent>=1.0 +twisted[tls] +gevent eventlet -cython>=0.21 +cython>=3.0 +packaging +futurist +asynctest diff --git a/tests/__init__.py b/tests/__init__.py index abfb8bf792..7799b51399 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,9 +14,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import unittest import logging import sys import socket +import platform +import os +from concurrent.futures import ThreadPoolExecutor + +from cassandra import DependencyException log = logging.getLogger() log.setLevel('DEBUG') @@ -28,9 +36,12 @@ def is_eventlet_monkey_patched(): if 'eventlet.patcher' not in sys.modules: return False - import eventlet.patcher - return eventlet.patcher.is_monkey_patched('socket') - + try: + import eventlet.patcher + return eventlet.patcher.is_monkey_patched('socket') + # Yet another case related to PYTHON-1364 + except AttributeError: + return False def is_gevent_monkey_patched(): if 'gevent.monkey' not in sys.modules: @@ -41,3 +52,65 @@ def is_gevent_monkey_patched(): def is_monkey_patched(): return is_gevent_monkey_patched() or is_eventlet_monkey_patched() + +MONKEY_PATCH_LOOP = bool(os.getenv('MONKEY_PATCH_LOOP', False)) +EVENT_LOOP_MANAGER = os.getenv('EVENT_LOOP_MANAGER', "libev") + + +# If set to to true this will force the Cython tests to run regardless of whether they are installed +cython_env = os.getenv('VERIFY_CYTHON', "False") + +VERIFY_CYTHON = False +if(cython_env == 'True'): + VERIFY_CYTHON = True + +thread_pool_executor_class = ThreadPoolExecutor + +if "gevent" in EVENT_LOOP_MANAGER: + import gevent.monkey + gevent.monkey.patch_all() + from cassandra.io.geventreactor import GeventConnection + connection_class = GeventConnection +elif "eventlet" in EVENT_LOOP_MANAGER: + from eventlet import monkey_patch + monkey_patch() + + from cassandra.io.eventletreactor import EventletConnection + connection_class = EventletConnection + + try: + from futurist import GreenThreadPoolExecutor + thread_pool_executor_class = GreenThreadPoolExecutor + except: + # futurist is installed only with python >=3.7 + pass +elif "asyncore" in EVENT_LOOP_MANAGER: + from cassandra.io.asyncorereactor import AsyncoreConnection + connection_class = AsyncoreConnection +elif "twisted" in EVENT_LOOP_MANAGER: + from cassandra.io.twistedreactor import TwistedConnection + connection_class = TwistedConnection +elif "asyncio" in EVENT_LOOP_MANAGER: + from cassandra.io.asyncioreactor import AsyncioConnection + connection_class = AsyncioConnection +else: + log.debug("Using default event loop (libev)") + try: + from 
cassandra.io.libevreactor import LibevConnection + connection_class = LibevConnection + except DependencyException as e: + log.debug('Could not import LibevConnection, ' + 'using connection_class=None; ' + 'failed with error:\n {}'.format( + repr(e) + )) + log.debug("Will attempt to set connection class at cluster initialization") + connection_class = None + + +def is_windows(): + return "Windows" in platform.system() + + +notwindows = unittest.skipUnless(not is_windows(), "This test is not adequate for windows") +notpypy = unittest.skipUnless(not platform.python_implementation() == 'PyPy', "This tests is not suitable for pypy") diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py index 4b3afaa045..3b0103db31 100644 --- a/tests/integration/__init__.py +++ b/tests/integration/__init__.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,28 +14,36 @@ # See the License for the specific language governing permissions and # limitations under the License. -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import re +import os +from cassandra.cluster import Cluster + +from tests import connection_class, EVENT_LOOP_MANAGER +Cluster.connection_class = connection_class +import unittest + +from packaging.version import Version import logging -import os import socket import sys import time import traceback +import platform from threading import Event from subprocess import call from itertools import groupby +import shutil -from cassandra import OperationTimedOut, ReadTimeout, ReadFailure, WriteTimeout, WriteFailure, AlreadyExists -from cassandra.cluster import Cluster +from cassandra import OperationTimedOut, ReadTimeout, ReadFailure, WriteTimeout, WriteFailure, AlreadyExists,\ + InvalidRequest from cassandra.protocol import ConfigurationException +from cassandra import ProtocolVersion try: - from ccmlib.cluster import Cluster as CCMCluster from ccmlib.dse_cluster import DseCluster + from ccmlib.hcd_cluster import HcdCluster + from ccmlib.cluster import Cluster as CCMCluster from ccmlib.cluster_factory import ClusterFactory as CCMClusterFactory from ccmlib import common except ImportError as e: @@ -45,6 +55,14 @@ SINGLE_NODE_CLUSTER_NAME = 'single_node' MULTIDC_CLUSTER_NAME = 'multidc_test_cluster' +# When use_single_interface is specified ccm will assign distinct port numbers to each +# node in the cluster. This value specifies the default port value used for the first +# node that comes up. +# +# TODO: In the future we may want to make this configurable, but this should only apply +# if a non-standard port were specified when starting up the cluster. 
+DEFAULT_SINGLE_INTERFACE_PORT=9046 + CCM_CLUSTER = None path = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'ccm') @@ -65,7 +83,7 @@ def get_server_versions(): if cass_version is not None: return (cass_version, cql_version) - c = Cluster() + c = TestCluster() s = c.connect() row = s.execute('SELECT cql_version, release_version FROM system.local')[0] @@ -84,41 +102,105 @@ def _tuple_version(version_string): return tuple([int(p) for p in version_string.split('.')]) -USE_CASS_EXTERNAL = bool(os.getenv('USE_CASS_EXTERNAL', False)) - -default_cassandra_version = '2.2.0' +def cmd_line_args_to_dict(env_var): + cmd_args_env = os.environ.get(env_var, None) + args = {} + if cmd_args_env: + cmd_args = cmd_args_env.strip().split(' ') + while cmd_args: + cmd_arg = cmd_args.pop(0) + cmd_arg_value = True if cmd_arg.startswith('--') else cmd_args.pop(0) + args[cmd_arg.lstrip('-')] = cmd_arg_value + return args def _get_cass_version_from_dse(dse_version): if dse_version.startswith('4.6') or dse_version.startswith('4.5'): - cass_ver = "2.0" + raise Exception("Cassandra Version 2.0 not supported anymore") elif dse_version.startswith('4.7') or dse_version.startswith('4.8'): cass_ver = "2.1" elif dse_version.startswith('5.0'): cass_ver = "3.0" + elif dse_version.startswith('5.1'): + # TODO: refactor this method to use packaging.Version everywhere + if Version(dse_version) >= Version('5.1.2'): + cass_ver = "3.11" + else: + cass_ver = "3.10" + elif dse_version.startswith('6.0'): + if dse_version == '6.0.0': + cass_ver = '4.0.0.2284' + elif dse_version == '6.0.1': + cass_ver = '4.0.0.2349' + else: + cass_ver = '4.0.0.' + ''.join(dse_version.split('.')) + elif Version(dse_version) >= Version('6.7'): + if dse_version == '6.7.0': + cass_ver = "4.0.0.67" + else: + cass_ver = '4.0.0.' + ''.join(dse_version.split('.')) + elif dse_version.startswith('6.8'): + if dse_version == '6.8.0': + cass_ver = "4.0.0.68" + else: + cass_ver = '4.0.0.' 
+ ''.join(dse_version.split('.')) else: - log.error("Uknown dse version found {0}, defaulting to 2.1".format(dse_version)) + log.error("Unknown dse version found {0}, defaulting to 2.1".format(dse_version)) cass_ver = "2.1" + return Version(cass_ver) - return cass_ver -CASSANDRA_DIR = os.getenv('CASSANDRA_DIR', None) -DSE_VERSION = os.getenv('DSE_VERSION', None) -DSE_CRED = os.getenv('DSE_CREDS', None) -if DSE_VERSION: - CASSANDRA_VERSION = _get_cass_version_from_dse(DSE_VERSION) -else: - CASSANDRA_VERSION = os.getenv('CASSANDRA_VERSION', default_cassandra_version) +def _get_cass_version_from_hcd(hcd_version): + return Version("4.0.11") -CCM_KWARGS = {} -if CASSANDRA_DIR: - log.info("Using Cassandra dir: %s", CASSANDRA_DIR) - CCM_KWARGS['install_dir'] = CASSANDRA_DIR -else: - log.info('Using Cassandra version: %s', CASSANDRA_VERSION) - CCM_KWARGS['version'] = CASSANDRA_VERSION +def _get_dse_version_from_cass(cass_version): + if cass_version.startswith('2.1'): + dse_ver = "4.8.15" + elif cass_version.startswith('3.0'): + dse_ver = "5.0.12" + elif cass_version.startswith('3.10') or cass_version.startswith('3.11'): + dse_ver = "5.1.7" + elif cass_version.startswith('4.0'): + dse_ver = "6.0" + else: + log.error("Unknown cassandra version found {0}, defaulting to 2.1".format(cass_version)) + dse_ver = "2.1" + return dse_ver + +USE_CASS_EXTERNAL = bool(os.getenv('USE_CASS_EXTERNAL', False)) +KEEP_TEST_CLUSTER = bool(os.getenv('KEEP_TEST_CLUSTER', False)) +SIMULACRON_JAR = os.getenv('SIMULACRON_JAR', None) +CLOUD_PROXY_PATH = os.getenv('CLOUD_PROXY_PATH', None) + +# Supported Clusters: Cassandra, DDAC, DSE, HCD +DSE_VERSION = None +HCD_VERSION = None +if os.getenv('DSE_VERSION', None): # we are testing against DSE + DSE_VERSION = Version(os.getenv('DSE_VERSION', None)) + DSE_CRED = os.getenv('DSE_CREDS', None) + CASSANDRA_VERSION = _get_cass_version_from_dse(DSE_VERSION.base_version) + CCM_VERSION = DSE_VERSION.base_version +elif os.getenv('HCD_VERSION', None): # we are testing against HCD + HCD_VERSION = Version(os.getenv('HCD_VERSION', None)) + CASSANDRA_VERSION = _get_cass_version_from_hcd(HCD_VERSION.base_version) + CCM_VERSION = HCD_VERSION.base_version +else: # we are testing against Cassandra or DDAC + cv_string = os.getenv('CASSANDRA_VERSION', None) + mcv_string = os.getenv('MAPPED_CASSANDRA_VERSION', None) + try: + cassandra_version = Version(cv_string) # env var is set to test-dse for DDAC + except: + # fallback to MAPPED_CASSANDRA_VERSION + cassandra_version = Version(mcv_string) + + CASSANDRA_VERSION = Version(mcv_string) if mcv_string else cassandra_version + CCM_VERSION = mcv_string if mcv_string else cv_string + +CASSANDRA_IP = os.getenv('CLUSTER_IP', '127.0.0.1') +CASSANDRA_DIR = os.getenv('CASSANDRA_DIR', None) +CCM_KWARGS = {} if DSE_VERSION: log.info('Using DSE version: %s', DSE_VERSION) if not CASSANDRA_DIR: @@ -126,33 +208,172 @@ def _get_cass_version_from_dse(dse_version): if DSE_CRED: log.info("Using DSE credentials file located at {0}".format(DSE_CRED)) CCM_KWARGS['dse_credentials_file'] = DSE_CRED +elif HCD_VERSION: + log.info('Using HCD version: %s', HCD_VERSION) + CCM_KWARGS['version'] = HCD_VERSION +elif CASSANDRA_DIR: + log.info("Using Cassandra dir: %s", CASSANDRA_DIR) + CCM_KWARGS['install_dir'] = CASSANDRA_DIR +else: + log.info('Using Cassandra version: %s', CCM_VERSION) + CCM_KWARGS['version'] = CCM_VERSION -if CASSANDRA_VERSION >= '2.2': - default_protocol_version = 4 -elif CASSANDRA_VERSION >= '2.1': - default_protocol_version = 3 -elif CASSANDRA_VERSION >= 
'2.0': - default_protocol_version = 2 -else: - default_protocol_version = 1 +ALLOW_BETA_PROTOCOL = False + + +def get_default_protocol(): + if CASSANDRA_VERSION >= Version('4.0-a'): + if DSE_VERSION: + return ProtocolVersion.DSE_V2 + else: + return ProtocolVersion.V5 + if CASSANDRA_VERSION >= Version('3.10'): + if DSE_VERSION: + return ProtocolVersion.DSE_V1 + else: + return 4 + if CASSANDRA_VERSION >= Version('2.2'): + return 4 + elif CASSANDRA_VERSION >= Version('2.1'): + return 3 + elif CASSANDRA_VERSION >= Version('2.0'): + return 2 + else: + raise Exception("Running tests with an unsupported Cassandra version: {0}".format(CASSANDRA_VERSION)) + + +def get_supported_protocol_versions(): + """ + 1.2 -> 1 + 2.0 -> 2, 1 + 2.1 -> 3, 2, 1 + 2.2 -> 4, 3, 2, 1 + 3.X -> 4, 3 + 3.10(C*) -> 5(beta),4,3 + 3.10(DSE) -> DSE_V1,4,3 + 4.0(C*) -> 6(beta),5,4,3 + 4.0(DSE) -> DSE_v2, DSE_V1,4,3 +` """ + if CASSANDRA_VERSION >= Version('4.0-beta5'): + if not DSE_VERSION: + return (3, 4, 5, 6) + if CASSANDRA_VERSION >= Version('4.0-a'): + if DSE_VERSION: + return (3, 4, ProtocolVersion.DSE_V1, ProtocolVersion.DSE_V2) + else: + return (3, 4, 5) + elif CASSANDRA_VERSION >= Version('3.10'): + if DSE_VERSION: + return (3, 4, ProtocolVersion.DSE_V1) + else: + return (3, 4) + elif CASSANDRA_VERSION >= Version('3.0'): + return (3, 4) + elif CASSANDRA_VERSION >= Version('2.2'): + return (1,2, 3, 4) + elif CASSANDRA_VERSION >= Version('2.1'): + return (1, 2, 3) + elif CASSANDRA_VERSION >= Version('2.0'): + return (1, 2) + else: + return (1,) + + +def get_unsupported_lower_protocol(): + """ + This is used to determine the lowest protocol version that is NOT + supported by the version of C* running + """ + if CASSANDRA_VERSION >= Version('3.0'): + return 2 + else: + return None + + +def get_unsupported_upper_protocol(): + """ + This is used to determine the highest protocol version that is NOT + supported by the version of C* running + """ + + if CASSANDRA_VERSION >= Version('4.0-a'): + if DSE_VERSION: + return None + else: + return ProtocolVersion.DSE_V1 + if CASSANDRA_VERSION >= Version('3.10'): + if DSE_VERSION: + return ProtocolVersion.DSE_V2 + else: + return 5 + if CASSANDRA_VERSION >= Version('2.2'): + return 5 + elif CASSANDRA_VERSION >= Version('2.1'): + return 4 + elif CASSANDRA_VERSION >= Version('2.0'): + return 3 + else: + return 2 + + +default_protocol_version = get_default_protocol() + PROTOCOL_VERSION = int(os.getenv('PROTOCOL_VERSION', default_protocol_version)) + +def local_decorator_creator(): + if USE_CASS_EXTERNAL or not CASSANDRA_IP.startswith("127.0.0."): + return unittest.skip('Tests only runs against local C*') + + def _id_and_mark(f): + f.local = True + return f + + return _id_and_mark + +local = local_decorator_creator() notprotocolv1 = unittest.skipUnless(PROTOCOL_VERSION > 1, 'Protocol v1 not supported') lessthenprotocolv4 = unittest.skipUnless(PROTOCOL_VERSION < 4, 'Protocol versions 4 or greater not supported') greaterthanprotocolv3 = unittest.skipUnless(PROTOCOL_VERSION >= 4, 'Protocol versions less than 4 are not supported') - -greaterthancass20 = unittest.skipUnless(CASSANDRA_VERSION >= '2.1', 'Cassandra version 2.1 or greater required') -greaterthancass21 = unittest.skipUnless(CASSANDRA_VERSION >= '2.2', 'Cassandra version 2.2 or greater required') -greaterthanorequalcass30 = unittest.skipUnless(CASSANDRA_VERSION >= '3.0', 'Cassandra version 3.0 or greater required') -lessthancass30 = unittest.skipUnless(CASSANDRA_VERSION < '3.0', 'Cassandra version less then 3.0 required') - 
+protocolv6 = unittest.skipUnless(6 in get_supported_protocol_versions(), 'Protocol version 6 is not supported')
+greaterthancass20 = unittest.skipUnless(CASSANDRA_VERSION >= Version('2.1'), 'Cassandra version 2.1 or greater required')
+greaterthancass21 = unittest.skipUnless(CASSANDRA_VERSION >= Version('2.2'), 'Cassandra version 2.2 or greater required')
+greaterthanorequalcass30 = unittest.skipUnless(CASSANDRA_VERSION >= Version('3.0'), 'Cassandra version 3.0 or greater required')
+greaterthanorequalcass31 = unittest.skipUnless(CASSANDRA_VERSION >= Version('3.1'), 'Cassandra version 3.1 or greater required')
+greaterthanorequalcass36 = unittest.skipUnless(CASSANDRA_VERSION >= Version('3.6'), 'Cassandra version 3.6 or greater required')
+greaterthanorequalcass3_10 = unittest.skipUnless(CASSANDRA_VERSION >= Version('3.10'), 'Cassandra version 3.10 or greater required')
+greaterthanorequalcass3_11 = unittest.skipUnless(CASSANDRA_VERSION >= Version('3.11'), 'Cassandra version 3.11 or greater required')
+greaterthanorequalcass40 = unittest.skipUnless(CASSANDRA_VERSION >= Version('4.0'), 'Cassandra version 4.0 or greater required')
+greaterthanorequalcass50 = unittest.skipUnless(CASSANDRA_VERSION >= Version('5.0-beta'), 'Cassandra version 5.0 or greater required')
+lessthanorequalcass40 = unittest.skipUnless(CASSANDRA_VERSION <= Version('4.0'), 'Cassandra version less than or equal to 4.0 required')
+lessthancass40 = unittest.skipUnless(CASSANDRA_VERSION < Version('4.0'), 'Cassandra version less than 4.0 required')
+lessthancass30 = unittest.skipUnless(CASSANDRA_VERSION < Version('3.0'), 'Cassandra version less than 3.0 required')
+greaterthanorequaldse68 = unittest.skipUnless(DSE_VERSION and DSE_VERSION >= Version('6.8'), "DSE 6.8 or greater required for this test")
+greaterthanorequaldse67 = unittest.skipUnless(DSE_VERSION and DSE_VERSION >= Version('6.7'), "DSE 6.7 or greater required for this test")
+greaterthanorequaldse60 = unittest.skipUnless(DSE_VERSION and DSE_VERSION >= Version('6.0'), "DSE 6.0 or greater required for this test")
+greaterthanorequaldse51 = unittest.skipUnless(DSE_VERSION and DSE_VERSION >= Version('5.1'), "DSE 5.1 or greater required for this test")
+greaterthanorequaldse50 = unittest.skipUnless(DSE_VERSION and DSE_VERSION >= Version('5.0'), "DSE 5.0 or greater required for this test")
+lessthandse51 = unittest.skipUnless(DSE_VERSION and DSE_VERSION < Version('5.1'), "DSE version less than 5.1 required")
+lessthandse60 = unittest.skipUnless(DSE_VERSION and DSE_VERSION < Version('6.0'), "DSE version less than 6.0 required")
+lessthandse69 = unittest.skipUnless(DSE_VERSION and DSE_VERSION < Version('6.9'), "DSE version less than 6.9 required")
+
+pypy = unittest.skipUnless(platform.python_implementation() == "PyPy", "Test is skipped unless it's on PyPy")
+requiresmallclockgranularity = unittest.skipIf("Windows" in platform.system() or "asyncore" in EVENT_LOOP_MANAGER,
+                                               "This test is not suitable for environments with large clock granularity")
+requiressimulacron = unittest.skipIf(SIMULACRON_JAR is None or CASSANDRA_VERSION < Version("2.1"), "Simulacron jar hasn't been specified or C* version is 2.0")
+requirecassandra = unittest.skipIf(DSE_VERSION, "Cassandra required")
+notdse = unittest.skipIf(DSE_VERSION, "DSE not supported")
+requiredse = unittest.skipUnless(DSE_VERSION, "DSE required")
+requirescloudproxy = unittest.skipIf(CLOUD_PROXY_PATH is None, "Cloud Proxy path hasn't been specified")
+
+libevtest = unittest.skipUnless(EVENT_LOOP_MANAGER=="libev",
"Test timing designed for libev loop") def wait_for_node_socket(node, timeout): binary_itf = node.network_interfaces['binary'] if not common.check_socket_listening(binary_itf, timeout=timeout): - log.warn("Unable to connect to binary socket for node " + node.name) + log.warning("Unable to connect to binary socket for node " + node.name) else: log.debug("Node %s is up and listening " % (node.name,)) @@ -172,6 +393,9 @@ def check_socket_listening(itf, timeout=60): return False +USE_SINGLE_INTERFACE = os.getenv('USE_SINGLE_INTERFACE', False) + + def get_cluster(): return CCM_CLUSTER @@ -180,24 +404,36 @@ def get_node(node_id): return CCM_CLUSTER.nodes['node%s' % node_id] -def use_multidc(dc_list, workloads=[]): +def use_multidc(dc_list, workloads=None): use_cluster(MULTIDC_CLUSTER_NAME, dc_list, start=True, workloads=workloads) -def use_singledc(start=True, workloads=[]): - use_cluster(CLUSTER_NAME, [3], start=start, workloads=workloads) +def use_singledc(start=True, workloads=None, use_single_interface=USE_SINGLE_INTERFACE): + use_cluster(CLUSTER_NAME, [3], start=start, workloads=workloads, use_single_interface=use_single_interface) -def use_single_node(start=True, workloads=[]): - use_cluster(SINGLE_NODE_CLUSTER_NAME, [1], start=start, workloads=workloads) +def use_single_node(start=True, workloads=None, configuration_options=None, dse_options=None): + use_cluster(SINGLE_NODE_CLUSTER_NAME, [1], start=start, workloads=workloads, + configuration_options=configuration_options, dse_options=dse_options) + + +def check_log_error(): + global CCM_CLUSTER + log.debug("Checking log error of cluster {0}".format(CCM_CLUSTER.name)) + for node in CCM_CLUSTER.nodelist(): + errors = node.grep_log_for_errors() + for error in errors: + for line in error: + print(line) def remove_cluster(): - if USE_CASS_EXTERNAL: + if USE_CASS_EXTERNAL or KEEP_TEST_CLUSTER: return global CCM_CLUSTER if CCM_CLUSTER: + check_log_error() log.debug("Removing cluster {0}".format(CCM_CLUSTER.name)) tries = 0 while tries < 100: @@ -207,7 +443,7 @@ def remove_cluster(): return except OSError: ex_type, ex, tb = sys.exc_info() - log.warn("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) del tb tries += 1 time.sleep(1) @@ -215,25 +451,61 @@ def remove_cluster(): raise RuntimeError("Failed to remove cluster after 100 attempts") -def is_current_cluster(cluster_name, node_counts): +def is_current_cluster(cluster_name, node_counts, workloads): global CCM_CLUSTER if CCM_CLUSTER and CCM_CLUSTER.name == cluster_name: if [len(list(nodes)) for dc, nodes in groupby(CCM_CLUSTER.nodelist(), lambda n: n.data_center)] == node_counts: + for node in CCM_CLUSTER.nodelist(): + if set(node.workloads) != set(workloads): + print("node workloads don't match creating new cluster") + return False return True return False -def use_cluster(cluster_name, nodes, ipformat=None, start=True, workloads=[]): +def start_cluster_wait_for_up(cluster): + cluster.start(wait_for_binary_proto=True) + # Added to wait for slow nodes to start up + log.debug("Cluster started waiting for binary ports") + for node in CCM_CLUSTER.nodes.values(): + wait_for_node_socket(node, 300) + log.debug("Binary port are open") + + +def use_cluster(cluster_name, nodes, ipformat=None, start=True, workloads=None, set_keyspace=True, ccm_options=None, + configuration_options=None, dse_options=None, use_single_interface=USE_SINGLE_INTERFACE): + configuration_options = 
configuration_options or {} + dse_options = dse_options or {} + workloads = workloads or [] + dse_cluster = True if DSE_VERSION else False + hcd_cluster = True if HCD_VERSION else False + + if ccm_options is None and (DSE_VERSION or HCD_VERSION): + ccm_options = {"version": CCM_VERSION} + elif ccm_options is None: + ccm_options = CCM_KWARGS.copy() + + cassandra_version = ccm_options.get('version', CCM_VERSION) + dse_version = ccm_options.get('version', DSE_VERSION) + global CCM_CLUSTER if USE_CASS_EXTERNAL: if CCM_CLUSTER: log.debug("Using external CCM cluster {0}".format(CCM_CLUSTER.name)) else: - log.debug("Using unnamed external cluster") + ccm_path = os.getenv("CCM_PATH", None) + ccm_name = os.getenv("CCM_NAME", None) + if ccm_path and ccm_name: + CCM_CLUSTER = CCMClusterFactory.load(ccm_path, ccm_name) + log.debug("Using external CCM cluster {0}".format(CCM_CLUSTER.name)) + else: + log.debug("Using unnamed external cluster") + if set_keyspace and start: + setup_keyspace(ipformat=ipformat, wait=False) return - if is_current_cluster(cluster_name, nodes): + if is_current_cluster(cluster_name, nodes, workloads): log.debug("Using existing cluster, matching topology: {0}".format(cluster_name)) else: if CCM_CLUSTER: @@ -244,59 +516,145 @@ def use_cluster(cluster_name, nodes, ipformat=None, start=True, workloads=[]): CCM_CLUSTER = CCMClusterFactory.load(path, cluster_name) log.debug("Found existing CCM cluster, {0}; clearing.".format(cluster_name)) CCM_CLUSTER.clear() - CCM_CLUSTER.set_install_dir(**CCM_KWARGS) + CCM_CLUSTER.set_install_dir(**ccm_options) + CCM_CLUSTER.set_configuration_options(configuration_options) + CCM_CLUSTER.set_dse_configuration_options(dse_options) except Exception: ex_type, ex, tb = sys.exc_info() - log.warn("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) del tb - log.debug("Creating new CCM cluster, {0}, with args {1}".format(cluster_name, CCM_KWARGS)) - if DSE_VERSION: - log.error("creating dse cluster") - CCM_CLUSTER = DseCluster(path, cluster_name, **CCM_KWARGS) - else: - CCM_CLUSTER = CCMCluster(path, cluster_name, **CCM_KWARGS) - CCM_CLUSTER.set_configuration_options({'start_native_transport': True}) - if CASSANDRA_VERSION >= '2.2': - CCM_CLUSTER.set_configuration_options({'enable_user_defined_functions': True}) - if CASSANDRA_VERSION >= '3.0': + ccm_options.update(cmd_line_args_to_dict('CCM_ARGS')) + + log.debug("Creating new CCM cluster, {0}, with args {1}".format(cluster_name, ccm_options)) + + # Make sure we cleanup old cluster dir if it exists + cluster_path = os.path.join(path, cluster_name) + if os.path.exists(cluster_path): + shutil.rmtree(cluster_path) + + if dse_cluster: + CCM_CLUSTER = DseCluster(path, cluster_name, **ccm_options) + CCM_CLUSTER.set_configuration_options({'start_native_transport': True}) + CCM_CLUSTER.set_configuration_options({'batch_size_warn_threshold_in_kb': 5}) + if Version(dse_version) >= Version('5.0'): + CCM_CLUSTER.set_configuration_options({'enable_user_defined_functions': True}) CCM_CLUSTER.set_configuration_options({'enable_scripted_user_defined_functions': True}) - if 'spark' in workloads: - config_options = {"initial_spark_worker_resources": 0.1} - CCM_CLUSTER.set_dse_configuration_options(config_options) - common.switch_cluster(path, cluster_name) - CCM_CLUSTER.populate(nodes, ipformat=ipformat) + if Version(dse_version) >= Version('5.1'): + # For Inet4Address + 
CCM_CLUSTER.set_dse_configuration_options({ + 'graph': { + 'gremlin_server': { + 'scriptEngines': { + 'gremlin-groovy': { + 'config': { + 'sandbox_rules': { + 'whitelist_packages': ['java.net'] + } + } + } + } + } + } + }) + if 'spark' in workloads: + if Version(dse_version) >= Version('6.8'): + config_options = { + "resource_manager_options": { + "worker_options": { + "cores_total": 0.1, + "memory_total": "64M" + } + } + } + else: + config_options = {"initial_spark_worker_resources": 0.1} + + if Version(dse_version) >= Version('6.7'): + log.debug("Disabling AlwaysON SQL for a DSE 6.7 Cluster") + config_options['alwayson_sql_options'] = {'enabled': False} + CCM_CLUSTER.set_dse_configuration_options(config_options) + common.switch_cluster(path, cluster_name) + CCM_CLUSTER.set_configuration_options(configuration_options) + CCM_CLUSTER.populate(nodes, ipformat=ipformat) + + CCM_CLUSTER.set_dse_configuration_options(dse_options) + elif hcd_cluster: + CCM_CLUSTER = HcdCluster(path, cluster_name, **ccm_options) + CCM_CLUSTER.set_configuration_options({'start_native_transport': True}) + CCM_CLUSTER.set_configuration_options({'batch_size_warn_threshold_in_kb': 5}) + CCM_CLUSTER.set_configuration_options({'enable_user_defined_functions': True}) + CCM_CLUSTER.set_configuration_options({'enable_scripted_user_defined_functions': True}) + CCM_CLUSTER.set_configuration_options({'enable_materialized_views': True}) + CCM_CLUSTER.set_configuration_options({'enable_sasi_indexes': True}) + CCM_CLUSTER.set_configuration_options({'enable_transient_replication': True}) + common.switch_cluster(path, cluster_name) + CCM_CLUSTER.set_configuration_options(configuration_options) + CCM_CLUSTER.populate(nodes, ipformat=ipformat, use_single_interface=use_single_interface) + else: + ccm_cluster_clz = CCMCluster if Version(cassandra_version) < Version('4.1') else Cassandra41CCMCluster + CCM_CLUSTER = ccm_cluster_clz(path, cluster_name, **ccm_options) + CCM_CLUSTER.set_configuration_options({'start_native_transport': True}) + if Version(cassandra_version) >= Version('2.2'): + CCM_CLUSTER.set_configuration_options({'enable_user_defined_functions': True}) + if Version(cassandra_version) >= Version('3.0'): + # The config.yml option below is deprecated in C* 4.0 per CASSANDRA-17280 + if Version(cassandra_version) < Version('4.0'): + CCM_CLUSTER.set_configuration_options({'enable_scripted_user_defined_functions': True}) + else: + # Cassandra version >= 4.0 + CCM_CLUSTER.set_configuration_options({ + 'enable_materialized_views': True, + 'enable_sasi_indexes': True, + 'enable_transient_replication': True, + }) + + common.switch_cluster(path, cluster_name) + CCM_CLUSTER.set_configuration_options(configuration_options) + CCM_CLUSTER.populate(nodes, ipformat=ipformat, use_single_interface=use_single_interface) + try: jvm_args = [] + # This will enable the Mirroring query handler which will echo our custom payload k,v pairs back - if PROTOCOL_VERSION >= 4: - jvm_args = [" -Dcassandra.custom_query_handler_class=org.apache.cassandra.cql3.CustomPayloadMirroringQueryHandler"] + if 'graph' in workloads: + jvm_args += ['-Xms1500M', '-Xmx1500M'] + else: + if PROTOCOL_VERSION >= 4: + jvm_args = [" -Dcassandra.custom_query_handler_class=org.apache.cassandra.cql3.CustomPayloadMirroringQueryHandler"] + if len(workloads) > 0: + for node in CCM_CLUSTER.nodes.values(): + node.set_workloads(workloads) if start: - if(len(workloads) > 0): - for node in CCM_CLUSTER.nodes.values(): - node.set_workloads(workloads) log.debug("Starting CCM cluster: 
{0}".format(cluster_name)) - CCM_CLUSTER.start(wait_for_binary_proto=True, wait_other_notice=True, jvm_args=jvm_args) + CCM_CLUSTER.start(jvm_args=jvm_args, wait_for_binary_proto=True) # Added to wait for slow nodes to start up + log.debug("Cluster started waiting for binary ports") for node in CCM_CLUSTER.nodes.values(): - wait_for_node_socket(node, 120) - setup_keyspace(ipformat=ipformat) + wait_for_node_socket(node, 300) + log.debug("Binary ports are open") + if set_keyspace: + args = {"ipformat": ipformat} + if use_single_interface: + args["port"] = DEFAULT_SINGLE_INTERFACE_PORT + setup_keyspace(**args) except Exception: log.exception("Failed to start CCM cluster; removing cluster.") if os.name == "nt": if CCM_CLUSTER: - for node in CCM_CLUSTER.nodes.itervalues(): + for node in CCM_CLUSTER.nodes.items(): os.system("taskkill /F /PID " + str(node.pid)) else: call(["pkill", "-9", "-f", ".ccm"]) remove_cluster() raise + return CCM_CLUSTER def teardown_package(): - if USE_CASS_EXTERNAL: + if USE_CASS_EXTERNAL or KEEP_TEST_CLUSTER: return # when multiple modules are run explicitly, this runs between them # need to make sure CCM_CLUSTER is properly cleared for that case @@ -319,13 +677,13 @@ def execute_until_pass(session, query): while tries < 100: try: return session.execute(query) - except (ConfigurationException, AlreadyExists): - log.warn("Recieved already exists from query {0} not exiting".format(query)) + except (ConfigurationException, AlreadyExists, InvalidRequest): + log.warning("Received already exists from query {0} not exiting".format(query)) # keyspace/table was already created/dropped return except (OperationTimedOut, ReadTimeout, ReadFailure, WriteTimeout, WriteFailure): ex_type, ex, tb = sys.exc_info() - log.warn("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) del tb tries += 1 @@ -338,38 +696,61 @@ def execute_with_long_wait_retry(session, query, timeout=30): try: return session.execute(query, timeout=timeout) except (ConfigurationException, AlreadyExists): - log.warn("Recieved already exists from query {0} not exiting".format(query)) + log.warning("Received already exists from query {0} not exiting".format(query)) # keyspace/table was already created/dropped return except (OperationTimedOut, ReadTimeout, ReadFailure, WriteTimeout, WriteFailure): ex_type, ex, tb = sys.exc_info() - log.warn("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) del tb tries += 1 raise RuntimeError("Failed to execute query after 100 attempts: {0}".format(query)) +def execute_with_retry_tolerant(session, query, retry_exceptions, escape_exception): + # TODO refactor above methods into this one for code reuse + tries = 0 + while tries < 100: + try: + tries += 1 + rs = session.execute(query) + return rs + except escape_exception: + return + except retry_exceptions: + time.sleep(.1) + + raise RuntimeError("Failed to execute query after 100 attempts: {0}".format(query)) + + def drop_keyspace_shutdown_cluster(keyspace_name, session, cluster): try: execute_with_long_wait_retry(session, "DROP KEYSPACE {0}".format(keyspace_name)) except: - log.warn("Error encountered when droping keyspace {0}".format(keyspace_name)) + log.warning("Error encountered when dropping keyspace {0}".format(keyspace_name)) ex_type, ex, tb = sys.exc_info() - log.warn("{0}: {1} 
Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) del tb - log.warn("Shutting down cluster") - cluster.shutdown() + finally: + log.warning("Shutting down cluster") + cluster.shutdown() -def setup_keyspace(ipformat=None): +def setup_keyspace(ipformat=None, wait=True, protocol_version=None, port=9042): # wait for nodes to startup - time.sleep(10) + if wait: + time.sleep(10) + + if protocol_version: + _protocol_version = protocol_version + else: + _protocol_version = PROTOCOL_VERSION if not ipformat: - cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cluster = TestCluster(protocol_version=_protocol_version, port=port) else: - cluster = Cluster(contact_points=["::1"], protocol_version=PROTOCOL_VERSION) + cluster = TestCluster(contact_points=["::1"], protocol_version=_protocol_version, port=port) session = cluster.connect() try: @@ -392,11 +773,17 @@ def setup_keyspace(ipformat=None): WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}''' execute_with_long_wait_retry(session, ddl) - ddl = ''' + ddl_3f = ''' CREATE TABLE test3rf.test ( k int PRIMARY KEY, v int )''' - execute_with_long_wait_retry(session, ddl) + execute_with_long_wait_retry(session, ddl_3f) + + ddl_1f = ''' + CREATE TABLE test1rf.test ( + k int PRIMARY KEY, + v int )''' + execute_with_long_wait_retry(session, ddl_1f) except Exception: traceback.print_exc() @@ -456,9 +843,9 @@ def create_keyspace(cls, rf): execute_with_long_wait_retry(cls.session, ddl) @classmethod - def common_setup(cls, rf, keyspace_creation=True, create_class_table=False): - cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION) - cls.session = cls.cluster.connect() + def common_setup(cls, rf, keyspace_creation=True, create_class_table=False, **cluster_kwargs): + cls.cluster = TestCluster(**cluster_kwargs) + cls.session = cls.cluster.connect(wait_for_all_pools=True) cls.ks_name = cls.__name__.lower() if keyspace_creation: cls.create_keyspace(rf) @@ -503,6 +890,29 @@ def reset(self): 'critical': [], } + def get_message_count(self, level, sub_string): + count = 0 + for msg in self.messages.get(level): + if sub_string in msg: + count+=1 + return count + + def set_module_name(self, module_name): + """ + This is intended to be used doing: + with MockLoggingHandler().set_module_name(connection.__name__) as mock_handler: + """ + self.module_name = module_name + return self + + def __enter__(self): + self.logger = logging.getLogger(self.module_name) + self.logger.addHandler(self) + return self + + def __exit__(self, *args): + pass + class BasicExistingKeyspaceUnitTestCase(BasicKeyspaceUnitTestCase): """ @@ -531,10 +941,10 @@ def tearDownClass(cls): drop_keyspace_shutdown_cluster(cls.ks_name, cls.session, cls.cluster) -class BasicSharedKeyspaceUnitTestCaseWTable(BasicSharedKeyspaceUnitTestCase): +class BasicSharedKeyspaceUnitTestCaseRF1(BasicSharedKeyspaceUnitTestCase): """ This is basic unit test case that can be leveraged to scope a keyspace to a specific test class. - creates a keyspace named after the testclass with a rf of 1, and a table named after the class + creates a keyspace named after the testclass with a rf of 1 """ @classmethod def setUpClass(self): @@ -551,16 +961,6 @@ def setUpClass(self): self.common_setup(2) -class BasicSharedKeyspaceUnitTestCaseWTable(BasicSharedKeyspaceUnitTestCase): - """ - This is basic unit test case that can be leveraged to scope a keyspace to a specific test class. 
- creates a keyspace named after the testc lass with a rf of 2, and a table named after the class - """ - @classmethod - def setUpClass(self): - self.common_setup(2, True) - - class BasicSharedKeyspaceUnitTestCaseRF3(BasicSharedKeyspaceUnitTestCase): """ This is basic unit test case that can be leveraged to scope a keyspace to a specific test class. @@ -571,14 +971,18 @@ def setUpClass(self): self.common_setup(3) -class BasicSharedKeyspaceUnitTestCaseRF3WTable(BasicSharedKeyspaceUnitTestCase): +class BasicSharedKeyspaceUnitTestCaseRF3WM(BasicSharedKeyspaceUnitTestCase): """ This is basic unit test case that can be leveraged to scope a keyspace to a specific test class. - creates a keyspace named after the test class with a rf of 3 and a table named after the class + creates a keyspace named after the test class with a rf of 3 with metrics enabled """ @classmethod def setUpClass(self): - self.common_setup(3, True) + self.common_setup(3, True, True, metrics_enabled=True) + + @classmethod + def tearDownClass(cls): + drop_keyspace_shutdown_cluster(cls.ks_name, cls.session, cls.cluster) class BasicSharedKeyspaceUnitTestCaseWFunctionTable(BasicSharedKeyspaceUnitTestCase): @@ -619,3 +1023,62 @@ def setUp(self): def tearDown(self): self.cluster.shutdown() + + +def assert_startswith(s, prefix): + if not s.startswith(prefix): + raise AssertionError( + '{} does not start with {}'.format(repr(s), repr(prefix)) + ) + + +class TestCluster(object): + __test__ = False + DEFAULT_PROTOCOL_VERSION = default_protocol_version + DEFAULT_CASSANDRA_IP = CASSANDRA_IP + DEFAULT_ALLOW_BETA = ALLOW_BETA_PROTOCOL + + def __new__(cls, **kwargs): + if 'protocol_version' not in kwargs: + kwargs['protocol_version'] = cls.DEFAULT_PROTOCOL_VERSION + if 'contact_points' not in kwargs: + kwargs['contact_points'] = [cls.DEFAULT_CASSANDRA_IP] + if 'allow_beta_protocol_version' not in kwargs: + kwargs['allow_beta_protocol_version'] = cls.DEFAULT_ALLOW_BETA + return Cluster(**kwargs) + +# Subclass of CCMCluster (i.e. ccmlib.cluster.Cluster) which transparently performs +# conversion of cassandra.yml directives into something matching the new syntax +# introduced by CASSANDRA-15234 +class Cassandra41CCMCluster(CCMCluster): + __test__ = False + IN_MS_REGEX = re.compile('^(\w+)_in_ms$') + IN_KB_REGEX = re.compile('^(\w+)_in_kb$') + ENABLE_REGEX = re.compile('^enable_(\w+)$') + + def _get_config_key(self, k, v): + if "." in k: + return k + m = self.IN_MS_REGEX.match(k) + if m: + return m.group(1) + m = self.ENABLE_REGEX.search(k) + if m: + return "%s_enabled" % (m.group(1)) + m = self.IN_KB_REGEX.match(k) + if m: + return m.group(1) + return k + + def _get_config_val(self, k, v): + m = self.IN_MS_REGEX.match(k) + if m: + return "%sms" % (v) + m = self.IN_KB_REGEX.match(k) + if m: + return "%sKiB" % (v) + return v + + def set_configuration_options(self, values=None, *args, **kwargs): + new_values = {self._get_config_key(k, str(v)):self._get_config_val(k, str(v)) for (k,v) in values.items()} + super(Cassandra41CCMCluster, self).set_configuration_options(values=new_values, *args, **kwargs) \ No newline at end of file diff --git a/tests/integration/advanced/__init__.py b/tests/integration/advanced/__init__.py new file mode 100644 index 0000000000..b1ed70f157 --- /dev/null +++ b/tests/integration/advanced/__init__.py @@ -0,0 +1,164 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from urllib.request import build_opener, Request, HTTPHandler +import re +import os +import time +from os.path import expanduser + +from ccmlib import common + +from tests.integration import get_server_versions, BasicKeyspaceUnitTestCase, \ + drop_keyspace_shutdown_cluster, get_node, USE_CASS_EXTERNAL, TestCluster +from tests.integration import use_singledc, use_single_node, wait_for_node_socket, CASSANDRA_IP + +home = expanduser('~') + +# Home directory of the Embedded Apache Directory Server to use +ADS_HOME = os.getenv('ADS_HOME', home) + + +def find_spark_master(session): + + # Iterate over the nodes the one with port 7080 open is the spark master + for host in session.hosts: + ip = host.address + port = 7077 + spark_master = (ip, port) + if common.check_socket_listening(spark_master, timeout=3): + return spark_master[0] + return None + + +def wait_for_spark_workers(num_of_expected_workers, timeout): + """ + This queries the spark master and checks for the expected number of workers + """ + start_time = time.time() + while True: + opener = build_opener(HTTPHandler) + request = Request("http://{0}:7080".format(CASSANDRA_IP)) + request.get_method = lambda: 'GET' + connection = opener.open(request) + match = re.search('Alive Workers:.*(\d+)', connection.read().decode('utf-8')) + num_workers = int(match.group(1)) + if num_workers == num_of_expected_workers: + match = True + break + elif time.time() - start_time > timeout: + match = True + break + time.sleep(1) + return match + + +def use_single_node_with_graph(start=True, options={}, dse_options={}): + use_single_node(start=start, workloads=['graph'], configuration_options=options, dse_options=dse_options) + + +def use_single_node_with_graph_and_spark(start=True, options={}): + use_single_node(start=start, workloads=['graph', 'spark'], configuration_options=options) + + +def use_single_node_with_graph_and_solr(start=True, options={}): + use_single_node(start=start, workloads=['graph', 'solr'], configuration_options=options) + + +def use_singledc_wth_graph(start=True): + use_singledc(start=start, workloads=['graph']) + + +def use_singledc_wth_graph_and_spark(start=True): + use_cluster_with_graph(3) + + +def use_cluster_with_graph(num_nodes): + """ + This is a workaround to account for the fact that spark nodes will conflict over master assignment + when started all at once. + """ + if USE_CASS_EXTERNAL: + return + + # Create the cluster but don't start it. + use_singledc(start=False, workloads=['graph', 'spark']) + # Start first node. 
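+    # Node 1 is brought up on its own so the Spark master election settles on it
+    # before the remaining nodes join (see the port 7080 master check below).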
+ get_node(1).start(wait_for_binary_proto=True) + # Wait binary protocol port to open + wait_for_node_socket(get_node(1), 120) + # Wait for spark master to start up + spark_master_http = ("localhost", 7080) + common.check_socket_listening(spark_master_http, timeout=60) + tmp_cluster = TestCluster() + + # Start up remaining nodes. + try: + session = tmp_cluster.connect() + statement = "ALTER KEYSPACE dse_leases WITH REPLICATION = {'class': 'NetworkTopologyStrategy', 'dc1': '%d'}" % (num_nodes) + session.execute(statement) + finally: + tmp_cluster.shutdown() + + for i in range(1, num_nodes+1): + if i is not 1: + node = get_node(i) + node.start(wait_for_binary_proto=True) + wait_for_node_socket(node, 120) + + # Wait for workers to show up as Alive on master + wait_for_spark_workers(3, 120) + + +class BasicGeometricUnitTestCase(BasicKeyspaceUnitTestCase): + """ + This base test class is used by all the geometric tests. It contains class level teardown and setup + methods. It also contains the test fixtures used by those tests + """ + + @classmethod + def common_dse_setup(cls, rf, keyspace_creation=True): + cls.cluster = TestCluster() + cls.session = cls.cluster.connect() + cls.ks_name = cls.__name__.lower() + if keyspace_creation: + cls.create_keyspace(rf) + cls.cass_version, cls.cql_version = get_server_versions() + cls.session.set_keyspace(cls.ks_name) + + @classmethod + def setUpClass(cls): + cls.common_dse_setup(1) + cls.initalizeTables() + + @classmethod + def tearDownClass(cls): + drop_keyspace_shutdown_cluster(cls.ks_name, cls.session, cls.cluster) + + @classmethod + def initalizeTables(cls): + udt_type = "CREATE TYPE udt1 (g {0})".format(cls.cql_type_name) + large_table = "CREATE TABLE tbl (k uuid PRIMARY KEY, g {0}, l list<{0}>, s set<{0}>, m0 map<{0},int>, m1 map, t tuple<{0},{0},{0}>, u frozen)".format( + cls.cql_type_name) + simple_table = "CREATE TABLE tblpk (k {0} primary key, v int)".format(cls.cql_type_name) + cluster_table = "CREATE TABLE tblclustering (k0 int, k1 {0}, v int, primary key (k0, k1))".format( + cls.cql_type_name) + cls.session.execute(udt_type) + cls.session.execute(large_table) + cls.session.execute(simple_table) + cls.session.execute(cluster_table) diff --git a/tests/integration/advanced/graph/__init__.py b/tests/integration/advanced/graph/__init__.py new file mode 100644 index 0000000000..71554c9bad --- /dev/null +++ b/tests/integration/advanced/graph/__init__.py @@ -0,0 +1,1197 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
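+
+# Shared fixtures and base classes for the DSE graph integration tests.
+#
+# Illustrative usage sketch (the test class name is hypothetical; the decorator,
+# base class and fixtures referenced are the ones defined in this module):
+#
+#     @GraphTestConfiguration.generate_tests(schema=ClassicGraphSchema)
+#     class ExampleGraphTest(GraphUnitTestCase):
+#         def _test_roundtrip(self, schema, graphson):
+#             self.execute_graph(schema.fixtures.classic(), graphson)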
+ +import sys +import logging +import inspect +from packaging.version import Version +import ipaddress +from uuid import UUID +from decimal import Decimal +import datetime + +from cassandra.util import Point, LineString, Polygon, Duration + +from cassandra.cluster import EXEC_PROFILE_GRAPH_DEFAULT, EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT +from cassandra.cluster import GraphAnalyticsExecutionProfile, GraphExecutionProfile, EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT, \ + default_lbp_factory +from cassandra.policies import DSELoadBalancingPolicy + +from cassandra.graph import GraphSON1Deserializer +from cassandra.graph.graphson import InetTypeIO, GraphSON2Deserializer, GraphSON3Deserializer +from cassandra.graph import Edge, Vertex, Path +from cassandra.graph.query import GraphOptions, GraphProtocol, graph_graphson2_row_factory, \ + graph_graphson3_row_factory + +from tests.integration import DSE_VERSION +from tests.integration.advanced import * + + +def setup_module(): + if DSE_VERSION: + dse_options = {'graph': {'realtime_evaluation_timeout_in_seconds': 60}} + use_single_node_with_graph(dse_options=dse_options) + + +log = logging.getLogger(__name__) + +MAX_LONG = 9223372036854775807 +MIN_LONG = -9223372036854775808 +ZERO_LONG = 0 + +MAKE_STRICT = "schema.config().option('graph.schema_mode').set('production')" +MAKE_NON_STRICT = "schema.config().option('graph.schema_mode').set('development')" +ALLOW_SCANS = "schema.config().option('graph.allow_scan').set('true')" + +deserializer_plus_to_ipaddressv4 = lambda x: ipaddress.IPv4Address(GraphSON1Deserializer.deserialize_inet(x)) +deserializer_plus_to_ipaddressv6 = lambda x: ipaddress.IPv6Address(GraphSON1Deserializer.deserialize_inet(x)) + + +def generic_ip_deserializer(string_ip_address): + if ":" in string_ip_address: + return deserializer_plus_to_ipaddressv6(string_ip_address) + return deserializer_plus_to_ipaddressv4(string_ip_address) + + +class GenericIpAddressIO(InetTypeIO): + @classmethod + def deserialize(cls, value, reader=None): + return generic_ip_deserializer(value) + +GraphSON2Deserializer._deserializers[GenericIpAddressIO.graphson_type] = GenericIpAddressIO +GraphSON3Deserializer._deserializers[GenericIpAddressIO.graphson_type] = GenericIpAddressIO + +if DSE_VERSION: + if DSE_VERSION >= Version('6.8.0'): + CREATE_CLASSIC_GRAPH = "system.graph(name).engine(Classic).create()" + else: + CREATE_CLASSIC_GRAPH = "system.graph(name).create()" + + +def reset_graph(session, graph_name): + ks = list(session.execute( + "SELECT * FROM system_schema.keyspaces WHERE keyspace_name = '{}';".format(graph_name))) + if ks: + try: + session.execute_graph('system.graph(name).drop()', {'name': graph_name}, + execution_profile=EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT) + except: + pass + + session.execute_graph(CREATE_CLASSIC_GRAPH, {'name': graph_name}, + execution_profile=EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT) + wait_for_graph_inserted(session, graph_name) + + +def wait_for_graph_inserted(session, graph_name): + count = 0 + exists = session.execute_graph('system.graph(name).exists()', {'name': graph_name}, + execution_profile=EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT)[0].value + while not exists and count < 50: + time.sleep(1) + exists = session.execute_graph('system.graph(name).exists()', {'name': graph_name}, + execution_profile=EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT)[0].value + return exists + + +class BasicGraphUnitTestCase(BasicKeyspaceUnitTestCase): + """ + This is basic graph unit test case that provides various utility methods that can be leveraged for testcase setup and tear + 
down + """ + + @property + def graph_name(self): + return self._testMethodName.lower() + + def session_setup(self): + lbp = DSELoadBalancingPolicy(default_lbp_factory()) + + ep_graphson2 = GraphExecutionProfile( + request_timeout=60, + load_balancing_policy=lbp, + graph_options=GraphOptions( + graph_name=self.graph_name, + graph_protocol=GraphProtocol.GRAPHSON_2_0 + ), + row_factory=graph_graphson2_row_factory) + + ep_graphson3 = GraphExecutionProfile( + request_timeout=60, + load_balancing_policy=lbp, + graph_options=GraphOptions( + graph_name=self.graph_name, + graph_protocol=GraphProtocol.GRAPHSON_3_0 + ), + row_factory=graph_graphson3_row_factory) + + ep_graphson1 = GraphExecutionProfile( + request_timeout=60, + load_balancing_policy=lbp, + graph_options=GraphOptions( + graph_name=self.graph_name + ) + ) + + ep_analytics = GraphAnalyticsExecutionProfile( + request_timeout=60, + load_balancing_policy=lbp, + graph_options=GraphOptions( + graph_source=b'a', + graph_language=b'gremlin-groovy', + graph_name=self.graph_name + ) + ) + + self.cluster = TestCluster(execution_profiles={ + EXEC_PROFILE_GRAPH_DEFAULT: ep_graphson1, + EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT: ep_analytics, + "graphson1": ep_graphson1, + "graphson2": ep_graphson2, + "graphson3": ep_graphson3 + }) + + self.session = self.cluster.connect() + self.ks_name = self._testMethodName.lower() + self.cass_version, self.cql_version = get_server_versions() + + def setUp(self): + self.session_setup() + self.reset_graph() + self.clear_schema() + # enable dev and scan modes + self.session.execute_graph(MAKE_NON_STRICT) + self.session.execute_graph(ALLOW_SCANS) + + def tearDown(self): + self.cluster.shutdown() + + def clear_schema(self): + self.session.execute_graph(""" + schema.clear(); + """) + + def reset_graph(self): + reset_graph(self.session, self.graph_name) + + def wait_for_graph_inserted(self): + wait_for_graph_inserted(self.session, self.graph_name) + + def _execute(self, query, graphson, params=None, execution_profile_options=None, **kwargs): + queries = query if isinstance(query, list) else [query] + ep = self.get_execution_profile(graphson) + if execution_profile_options: + ep = self.session.execution_profile_clone_update(ep, **execution_profile_options) + + results = [] + for query in queries: + log.debug(query) + rf = self.session.execute_graph_async(query, parameters=params, execution_profile=ep, **kwargs) + results.append(rf.result()) + self.assertEqual(rf.message.custom_payload['graph-results'], graphson) + + return results[0] if len(results) == 1 else results + + def get_execution_profile(self, graphson, traversal=False): + ep = 'graphson1' + if graphson == GraphProtocol.GRAPHSON_2_0: + ep = 'graphson2' + elif graphson == GraphProtocol.GRAPHSON_3_0: + ep = 'graphson3' + + return ep if traversal is False else 'traversal_' + ep + + def resultset_to_list(self, rs): + results_list = [] + for result in rs: + try: + results_list.append(result.value) + except: + results_list.append(result) + + return results_list + + +class GraphUnitTestCase(BasicKeyspaceUnitTestCase): + + @property + def graph_name(self): + return self._testMethodName.lower() + + def session_setup(self): + lbp = DSELoadBalancingPolicy(default_lbp_factory()) + + ep_graphson2 = GraphExecutionProfile( + request_timeout=60, + load_balancing_policy=lbp, + graph_options=GraphOptions( + graph_name=self.graph_name, + graph_protocol=GraphProtocol.GRAPHSON_2_0 + ), + row_factory=graph_graphson2_row_factory) + + ep_graphson3 = GraphExecutionProfile( + 
request_timeout=60, + load_balancing_policy=lbp, + graph_options=GraphOptions( + graph_name=self.graph_name, + graph_protocol=GraphProtocol.GRAPHSON_3_0 + ), + row_factory=graph_graphson3_row_factory) + + ep_graphson1 = GraphExecutionProfile( + request_timeout=60, + load_balancing_policy=lbp, + graph_options=GraphOptions( + graph_name=self.graph_name, + graph_language='gremlin-groovy' + ) + ) + + ep_analytics = GraphAnalyticsExecutionProfile( + request_timeout=60, + load_balancing_policy=lbp, + graph_options=GraphOptions( + graph_source=b'a', + graph_language=b'gremlin-groovy', + graph_name=self.graph_name + ) + ) + + self.cluster = TestCluster(execution_profiles={ + EXEC_PROFILE_GRAPH_DEFAULT: ep_graphson1, + EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT: ep_analytics, + "graphson1": ep_graphson1, + "graphson2": ep_graphson2, + "graphson3": ep_graphson3 + }) + + self.session = self.cluster.connect() + self.ks_name = self._testMethodName.lower() + self.cass_version, self.cql_version = get_server_versions() + + def setUp(self): + """basic setup only""" + self.session_setup() + + def setup_graph(self, schema): + """Config dependant setup""" + schema.drop_graph(self.session, self.graph_name) + schema.create_graph(self.session, self.graph_name) + schema.clear(self.session) + if schema is ClassicGraphSchema: + # enable dev and scan modes + self.session.execute_graph(MAKE_NON_STRICT) + self.session.execute_graph(ALLOW_SCANS) + + def teardown_graph(self, schema): + schema.drop_graph(self.session, self.graph_name) + + def tearDown(self): + self.cluster.shutdown() + + def execute_graph_queries(self, queries, params=None, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT, + verify_graphson=False, **kwargs): + results = [] + for query in queries: + log.debug(query) + rf = self.session.execute_graph_async(query, parameters=params, + execution_profile=execution_profile, **kwargs) + if verify_graphson: + self.assertEqual(rf.message.custom_payload['graph-results'], verify_graphson) + results.append(rf.result()) + + return results + + def execute_graph(self, query, graphson, params=None, execution_profile_options=None, traversal=False, **kwargs): + queries = query if isinstance(query, list) else [query] + ep = self.get_execution_profile(graphson) + if traversal: + ep = 'traversal_' + ep + if execution_profile_options: + ep = self.session.execution_profile_clone_update(ep, **execution_profile_options) + + results = self.execute_graph_queries(queries, params, ep, verify_graphson=graphson, **kwargs) + + return results[0] if len(results) == 1 else results + + def get_execution_profile(self, graphson, traversal=False): + ep = 'graphson1' + if graphson == GraphProtocol.GRAPHSON_2_0: + ep = 'graphson2' + elif graphson == GraphProtocol.GRAPHSON_3_0: + ep = 'graphson3' + + return ep if traversal is False else 'traversal_' + ep + + def resultset_to_list(self, rs): + results_list = [] + for result in rs: + try: + results_list.append(result.value) + except: + results_list.append(result) + + return results_list + + +class BasicSharedGraphUnitTestCase(BasicKeyspaceUnitTestCase): + """ + This is basic graph unit test case that provides various utility methods that can be leveraged for testcase setup and tear + down + """ + + @classmethod + def session_setup(cls): + cls.cluster = TestCluster() + cls.session = cls.cluster.connect() + cls.ks_name = cls.__name__.lower() + cls.cass_version, cls.cql_version = get_server_versions() + cls.graph_name = cls.__name__.lower() + + @classmethod + def setUpClass(cls): + if DSE_VERSION: + 
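+            # The shared cluster/session and graph execution profiles are only
+            # created when running against DSE; otherwise setUpClass is a no-op.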
cls.session_setup() + cls.reset_graph() + profiles = cls.cluster.profile_manager.profiles + profiles[EXEC_PROFILE_GRAPH_DEFAULT].request_timeout = 60 + profiles[EXEC_PROFILE_GRAPH_DEFAULT].graph_options.graph_name = cls.graph_name + profiles[EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT].request_timeout = 60 + profiles[EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT].graph_options.graph_name = cls.graph_name + + @classmethod + def tearDownClass(cls): + if DSE_VERSION: + cls.cluster.shutdown() + + @classmethod + def clear_schema(self): + self.session.execute_graph('schema.clear()') + + @classmethod + def reset_graph(self): + reset_graph(self.session, self.graph_name) + + def wait_for_graph_inserted(self): + wait_for_graph_inserted(self.session, self.graph_name) + + +class GraphFixtures(object): + + @staticmethod + def line(length, single_script=True): + raise NotImplementedError() + + @staticmethod + def classic(): + raise NotImplementedError() + + @staticmethod + def multiple_fields(): + raise NotImplementedError() + + @staticmethod + def large(): + raise NotImplementedError() + + +class ClassicGraphFixtures(GraphFixtures): + + @staticmethod + def datatypes(): + data = { + "boolean1": ["Boolean()", True, None], + "boolean2": ["Boolean()", False, None], + "point1": ["Point()", Point(.5, .13), GraphSON1Deserializer.deserialize_point], + "point2": ["Point()", Point(-5, .0), GraphSON1Deserializer.deserialize_point], + + "linestring1": ["Linestring()", LineString(((1.0, 2.0), (3.0, 4.0), (-89.0, 90.0))), + GraphSON1Deserializer.deserialize_linestring], + "polygon1": ["Polygon()", Polygon([(10.0, 10.0), (80.0, 10.0), (80., 88.0), (10., 89.0), (10., 10.0)], + [[(20., 20.0), (20., 30.0), (30., 30.0), (30., 20.0), (20., 20.0)], + [(40., 20.0), (40., 30.0), (50., 30.0), (50., 20.0), (40., 20.0)]]), + GraphSON1Deserializer.deserialize_polygon], + "int1": ["Int()", 2, GraphSON1Deserializer.deserialize_int], + "smallint1": ["Smallint()", 1, GraphSON1Deserializer.deserialize_smallint], + "bigint1": ["Bigint()", MAX_LONG, GraphSON1Deserializer.deserialize_bigint], + "bigint2": ["Bigint()", MIN_LONG, GraphSON1Deserializer.deserialize_bigint], + "bigint3": ["Bigint()", ZERO_LONG, GraphSON1Deserializer.deserialize_bigint], + "varint1": ["Varint()", 2147483647, GraphSON1Deserializer.deserialize_varint], + "int1": ["Int()", 100, GraphSON1Deserializer.deserialize_int], + "float1": ["Float()", 0.3415681, GraphSON1Deserializer.deserialize_float], + "double1": ["Double()", 0.34156811237335205, GraphSON1Deserializer.deserialize_double], + "uuid1": ["Uuid()", UUID('12345678123456781234567812345678'), GraphSON1Deserializer.deserialize_uuid], + "decimal1": ["Decimal()", Decimal(10), GraphSON1Deserializer.deserialize_decimal], + "blob1": ["Blob()", bytearray(b"Hello World"), GraphSON1Deserializer.deserialize_blob], + + "timestamp1": ["Timestamp()", datetime.datetime.utcnow().replace(microsecond=0), + GraphSON1Deserializer.deserialize_timestamp], + "timestamp2": ["Timestamp()", datetime.datetime.max.replace(microsecond=0), + GraphSON1Deserializer.deserialize_timestamp], + # These are valid values but are pending for DSP-14093 to be fixed + #"timestamp3": ["Timestamp()", datetime.datetime(159, 1, 1, 23, 59, 59), + # GraphSON1TypeDeserializer.deserialize_timestamp], + #"timestamp4": ["Timestamp()", datetime.datetime.min, + # GraphSON1TypeDeserializer.deserialize_timestamp], + "inet1": ["Inet()", ipaddress.IPv4Address(u"127.0.0.1"), deserializer_plus_to_ipaddressv4], + "inet2": ["Inet()", 
ipaddress.IPv6Address(u"2001:db8:85a3:8d3:1319:8a2e:370:7348"), + deserializer_plus_to_ipaddressv6], + "duration1": ["Duration()", datetime.timedelta(1, 16, 0), + GraphSON1Deserializer.deserialize_duration], + "duration2": ["Duration()", datetime.timedelta(days=1, seconds=16, milliseconds=15), + GraphSON1Deserializer.deserialize_duration], + "blob3": ["Blob()", bytes(b"Hello World Again"), GraphSON1Deserializer.deserialize_blob], + "blob4": ["Blob()", memoryview(b"And Again Hello World"), GraphSON1Deserializer.deserialize_blob] + } + + if DSE_VERSION >= Version("5.1"): + data["time1"] = ["Time()", datetime.time(12, 6, 12, 444), GraphSON1Deserializer.deserialize_time] + data["time2"] = ["Time()", datetime.time(12, 6, 12), GraphSON1Deserializer.deserialize_time] + data["time3"] = ["Time()", datetime.time(12, 6), GraphSON1Deserializer.deserialize_time] + data["time4"] = ["Time()", datetime.time.min, GraphSON1Deserializer.deserialize_time] + data["time5"] = ["Time()", datetime.time.max, GraphSON1Deserializer.deserialize_time] + data["blob5"] = ["Blob()", bytearray(b"AKDLIElksadlaswqA" * 10000), GraphSON1Deserializer.deserialize_blob] + data["datetime1"] = ["Date()", datetime.date.today(), GraphSON1Deserializer.deserialize_date] + data["datetime2"] = ["Date()", datetime.date(159, 1, 3), GraphSON1Deserializer.deserialize_date] + data["datetime3"] = ["Date()", datetime.date.min, GraphSON1Deserializer.deserialize_date] + data["datetime4"] = ["Date()", datetime.date.max, GraphSON1Deserializer.deserialize_date] + data["time1"] = ["Time()", datetime.time(12, 6, 12, 444), GraphSON1Deserializer.deserialize_time] + data["time2"] = ["Time()", datetime.time(12, 6, 12), GraphSON1Deserializer.deserialize_time] + data["time3"] = ["Time()", datetime.time(12, 6), GraphSON1Deserializer.deserialize_time] + data["time4"] = ["Time()", datetime.time.min, GraphSON1Deserializer.deserialize_time] + data["time5"] = ["Time()", datetime.time.max, GraphSON1Deserializer.deserialize_time] + + return data + + @staticmethod + def line(length, single_script=False): + queries = [ALLOW_SCANS + ';', + """schema.propertyKey('index').Int().ifNotExists().create(); + schema.propertyKey('distance').Int().ifNotExists().create(); + schema.vertexLabel('lp').properties('index').ifNotExists().create(); + schema.edgeLabel('goesTo').properties('distance').connection('lp', 'lp').ifNotExists().create();"""] + + vertex_script = ["Vertex vertex0 = graph.addVertex(label, 'lp', 'index', 0);"] + for index in range(1, length): + if not single_script and len(vertex_script) > 25: + queries.append("\n".join(vertex_script)) + vertex_script = [ + "Vertex vertex{pindex} = g.V().hasLabel('lp').has('index', {pindex}).next()".format( + pindex=index-1)] + + vertex_script.append(''' + Vertex vertex{vindex} = graph.addVertex(label, 'lp', 'index', {vindex}); + vertex{pindex}.addEdge('goesTo', vertex{vindex}, 'distance', 5); '''.format( + vindex=index, pindex=index - 1)) + + queries.append("\n".join(vertex_script)) + return queries + + @staticmethod + def classic(): + queries = [ALLOW_SCANS, + '''schema.propertyKey('name').Text().ifNotExists().create(); + schema.propertyKey('age').Int().ifNotExists().create(); + schema.propertyKey('lang').Text().ifNotExists().create(); + schema.propertyKey('weight').Float().ifNotExists().create(); + schema.vertexLabel('person').properties('name', 'age').ifNotExists().create(); + schema.vertexLabel('software').properties('name', 'lang').ifNotExists().create(); + schema.edgeLabel('created').properties('weight').connection('person', 
'software').ifNotExists().create(); + schema.edgeLabel('created').connection('software', 'software').add(); + schema.edgeLabel('knows').properties('weight').connection('person', 'person').ifNotExists().create();''', + + '''Vertex marko = graph.addVertex(label, 'person', 'name', 'marko', 'age', 29); + Vertex vadas = graph.addVertex(label, 'person', 'name', 'vadas', 'age', 27); + Vertex lop = graph.addVertex(label, 'software', 'name', 'lop', 'lang', 'java'); + Vertex josh = graph.addVertex(label, 'person', 'name', 'josh', 'age', 32); + Vertex ripple = graph.addVertex(label, 'software', 'name', 'ripple', 'lang', 'java'); + Vertex peter = graph.addVertex(label, 'person', 'name', 'peter', 'age', 35); + Vertex carl = graph.addVertex(label, 'person', 'name', 'carl', 'age', 35); + marko.addEdge('knows', vadas, 'weight', 0.5f); + marko.addEdge('knows', josh, 'weight', 1.0f); + marko.addEdge('created', lop, 'weight', 0.4f); + josh.addEdge('created', ripple, 'weight', 1.0f); + josh.addEdge('created', lop, 'weight', 0.4f); + peter.addEdge('created', lop, 'weight', 0.2f);'''] + + return "\n".join(queries) + + @staticmethod + def multiple_fields(): + query_params = {} + queries= [ALLOW_SCANS, + '''schema.propertyKey('shortvalue').Smallint().ifNotExists().create(); + schema.vertexLabel('shortvertex').properties('shortvalue').ifNotExists().create(); + short s1 = 5000; graph.addVertex(label, "shortvertex", "shortvalue", s1); + schema.propertyKey('intvalue').Int().ifNotExists().create(); + schema.vertexLabel('intvertex').properties('intvalue').ifNotExists().create(); + int i1 = 1000000000; graph.addVertex(label, "intvertex", "intvalue", i1); + schema.propertyKey('intvalue2').Int().ifNotExists().create(); + schema.vertexLabel('intvertex2').properties('intvalue2').ifNotExists().create(); + Integer i2 = 100000000; graph.addVertex(label, "intvertex2", "intvalue2", i2); + schema.propertyKey('longvalue').Bigint().ifNotExists().create(); + schema.vertexLabel('longvertex').properties('longvalue').ifNotExists().create(); + long l1 = 9223372036854775807; graph.addVertex(label, "longvertex", "longvalue", l1); + schema.propertyKey('longvalue2').Bigint().ifNotExists().create(); + schema.vertexLabel('longvertex2').properties('longvalue2').ifNotExists().create(); + Long l2 = 100000000000000000L; graph.addVertex(label, "longvertex2", "longvalue2", l2); + schema.propertyKey('floatvalue').Float().ifNotExists().create(); + schema.vertexLabel('floatvertex').properties('floatvalue').ifNotExists().create(); + float f1 = 3.5f; graph.addVertex(label, "floatvertex", "floatvalue", f1); + schema.propertyKey('doublevalue').Double().ifNotExists().create(); + schema.vertexLabel('doublevertex').properties('doublevalue').ifNotExists().create(); + double d1 = 3.5e40; graph.addVertex(label, "doublevertex", "doublevalue", d1); + schema.propertyKey('doublevalue2').Double().ifNotExists().create(); + schema.vertexLabel('doublevertex2').properties('doublevalue2').ifNotExists().create(); + Double d2 = 3.5e40d; graph.addVertex(label, "doublevertex2", "doublevalue2", d2);'''] + + if DSE_VERSION >= Version('5.1'): + queries.append('''schema.propertyKey('datevalue1').Date().ifNotExists().create(); + schema.vertexLabel('datevertex1').properties('datevalue1').ifNotExists().create(); + schema.propertyKey('negdatevalue2').Date().ifNotExists().create(); + schema.vertexLabel('negdatevertex2').properties('negdatevalue2').ifNotExists().create();''') + + for i in range(1, 4): + 
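+                # One Time() property key and vertex label per sample value
+                # (timevalue1..3 / timevertex1..3, populated via time1..time3 below).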
queries.append('''schema.propertyKey('timevalue{0}').Time().ifNotExists().create(); + schema.vertexLabel('timevertex{0}').properties('timevalue{0}').ifNotExists().create();'''.format( + i)) + + queries.append('graph.addVertex(label, "datevertex1", "datevalue1", date1);') + query_params['date1'] = '1999-07-29' + + queries.append('graph.addVertex(label, "negdatevertex2", "negdatevalue2", date2);') + query_params['date2'] = '-1999-07-28' + + queries.append('graph.addVertex(label, "timevertex1", "timevalue1", time1);') + query_params['time1'] = '14:02' + queries.append('graph.addVertex(label, "timevertex2", "timevalue2", time2);') + query_params['time2'] = '14:02:20' + queries.append('graph.addVertex(label, "timevertex3", "timevalue3", time3);') + query_params['time3'] = '14:02:20.222' + + return queries, query_params + + @staticmethod + def large(): + query_parts = [''' + int size = 2000; + List ids = new ArrayList(); + schema.propertyKey('ts').Int().single().ifNotExists().create(); + schema.propertyKey('sin').Int().single().ifNotExists().create(); + schema.propertyKey('cos').Int().single().ifNotExists().create(); + schema.propertyKey('ii').Int().single().ifNotExists().create(); + schema.vertexLabel('lcg').properties('ts', 'sin', 'cos', 'ii').ifNotExists().create(); + schema.edgeLabel('linked').connection('lcg', 'lcg').ifNotExists().create(); + Vertex v = graph.addVertex(label, 'lcg'); + v.property("ts", 100001); + v.property("sin", 0); + v.property("cos", 1); + v.property("ii", 0); + ids.add(v.id()); + Random rand = new Random(); + for (int ii = 1; ii < size; ii++) { + v = graph.addVertex(label, 'lcg'); + v.property("ii", ii); + v.property("ts", 100001 + ii); + v.property("sin", Math.sin(ii/5.0)); + v.property("cos", Math.cos(ii/5.0)); + Vertex u = g.V(ids.get(rand.nextInt(ids.size()))).next(); + v.addEdge("linked", u); + ids.add(v.id()); + } + g.V().count();'''] + + return "\n".join(query_parts) + + @staticmethod + def address_book(): + p1 = "Point()" + p2 = "Point()" + if DSE_VERSION >= Version('5.1'): + p1 = "Point().withBounds(-100, -100, 100, 100)" + p2 = "Point().withGeoBounds()" + + queries = [ + ALLOW_SCANS, + "schema.propertyKey('name').Text().ifNotExists().create()", + "schema.propertyKey('pointPropWithBoundsWithSearchIndex').{}.ifNotExists().create()".format(p1), + "schema.propertyKey('pointPropWithBounds').{}.ifNotExists().create()".format(p1), + "schema.propertyKey('pointPropWithGeoBoundsWithSearchIndex').{}.ifNotExists().create()".format(p2), + "schema.propertyKey('pointPropWithGeoBounds').{}.ifNotExists().create()".format(p2), + "schema.propertyKey('city').Text().ifNotExists().create()", + "schema.propertyKey('state').Text().ifNotExists().create()", + "schema.propertyKey('description').Text().ifNotExists().create()", + "schema.vertexLabel('person').properties('name', 'city', 'state', 'description', 'pointPropWithBoundsWithSearchIndex', 'pointPropWithBounds', 'pointPropWithGeoBoundsWithSearchIndex', 'pointPropWithGeoBounds').ifNotExists().create()", + "schema.vertexLabel('person').index('searchPointWithBounds').secondary().by('pointPropWithBounds').ifNotExists().add()", + "schema.vertexLabel('person').index('searchPointWithGeoBounds').secondary().by('pointPropWithGeoBounds').ifNotExists().add()", + + "g.addV('person').property('name', 'Paul Thomas Joe').property('city', 'Rochester').property('state', 'MN').property('pointPropWithBoundsWithSearchIndex', Geo.point(-92.46295, 44.0234)).property('pointPropWithBounds', Geo.point(-92.46295, 
44.0234)).property('pointPropWithGeoBoundsWithSearchIndex', Geo.point(-92.46295, 44.0234)).property('pointPropWithGeoBounds', Geo.point(-92.46295, 44.0234)).property('description', 'Lives by the hospital').next()", + "g.addV('person').property('name', 'George Bill Steve').property('city', 'Minneapolis').property('state', 'MN').property('pointPropWithBoundsWithSearchIndex', Geo.point(-93.266667, 44.093333)).property('pointPropWithBounds', Geo.point(-93.266667, 44.093333)).property('pointPropWithGeoBoundsWithSearchIndex', Geo.point(-93.266667, 44.093333)).property('pointPropWithGeoBounds', Geo.point(-93.266667, 44.093333)).property('description', 'A cold dude').next()", + "g.addV('person').property('name', 'James Paul Smith').property('city', 'Chicago').property('state', 'IL').property('pointPropWithBoundsWithSearchIndex', Geo.point(-87.684722, 41.836944)).property('description', 'Likes to hang out').next()", + "g.addV('person').property('name', 'Jill Alice').property('city', 'Atlanta').property('state', 'GA').property('pointPropWithBoundsWithSearchIndex', Geo.point(-84.39, 33.755)).property('description', 'Enjoys a nice cold coca cola').next()" + ] + + if not Version('5.0') <= DSE_VERSION < Version('5.1'): + queries.append("schema.vertexLabel('person').index('search').search().by('pointPropWithBoundsWithSearchIndex').withError(0.00001, 0.0).by('pointPropWithGeoBoundsWithSearchIndex').withError(0.00001, 0.0).ifNotExists().add()") + + return "\n".join(queries) + + +class CoreGraphFixtures(GraphFixtures): + + @staticmethod + def datatypes(): + data = ClassicGraphFixtures.datatypes() + del data['duration1'] + del data['duration2'] + + # Core Graphs only types + data["map1"] = ["mapOf(Text, Text)", {'test': 'test'}, None] + data["map2"] = ["mapOf(Text, Point)", {'test': Point(.5, .13)}, None] + data["map3"] = ["frozen(mapOf(Int, Varchar))", {42: 'test'}, None] + + data["list1"] = ["listOf(Text)", ['test', 'hello', 'world'], None] + data["list2"] = ["listOf(Int)", [42, 632, 32], None] + data["list3"] = ["listOf(Point)", [Point(.5, .13), Point(42.5, .13)], None] + data["list4"] = ["frozen(listOf(Int))", [42, 55, 33], None] + + data["set1"] = ["setOf(Text)", {'test', 'hello', 'world'}, None] + data["set2"] = ["setOf(Int)", {42, 632, 32}, None] + data["set3"] = ["setOf(Point)", {Point(.5, .13), Point(42.5, .13)}, None] + data["set4"] = ["frozen(setOf(Int))", {42, 55, 33}, None] + + data["tuple1"] = ["tupleOf(Int, Text)", (42, "world"), None] + data["tuple2"] = ["tupleOf(Int, tupleOf(Text, tupleOf(Text, Point)))", (42, ("world", ('this', Point(.5, .13)))), None] + data["tuple3"] = ["tupleOf(Int, tupleOf(Text, frozen(mapOf(Text, Text))))", (42, ("world", {'test': 'test'})), None] + data["tuple4"] = ["tupleOf(Int, tupleOf(Text, frozen(listOf(Int))))", (42, ("world", [65, 89])), None] + data["tuple5"] = ["tupleOf(Int, tupleOf(Text, frozen(setOf(Int))))", (42, ("world", {65, 55})), None] + data["tuple6"] = ["tupleOf(Int, tupleOf(Text, tupleOf(Text, LineString)))", + (42, ("world", ('this', LineString(((1.0, 2.0), (3.0, 4.0), (-89.0, 90.0)))))), None] + + data["tuple7"] = ["tupleOf(Int, tupleOf(Text, tupleOf(Text, Polygon)))", + (42, ("world", ('this', Polygon([(10.0, 10.0), (80.0, 10.0), (80., 88.0), (10., 89.0), (10., 10.0)], + [[(20., 20.0), (20., 30.0), (30., 30.0), (30., 20.0), (20., 20.0)], + [(40., 20.0), (40., 30.0), (50., 30.0), (50., 20.0), (40., 20.0)]])))), None] + data["dse_duration1"] = ["Duration()", Duration(42, 12, 10303312), None] + data["dse_duration2"] = ["Duration()", Duration(50, 32, 
11), None] + + return data + + @staticmethod + def line(length, single_script=False): + queries = [""" + schema.vertexLabel('lp').ifNotExists().partitionBy('index', Int).create(); + schema.edgeLabel('goesTo').ifNotExists().from('lp').to('lp').property('distance', Int).create(); + """] + + vertex_script = ["g.addV('lp').property('index', 0).next();"] + for index in range(1, length): + if not single_script and len(vertex_script) > 25: + queries.append("\n".join(vertex_script)) + vertex_script = [] + + vertex_script.append(''' + g.addV('lp').property('index', {index}).next(); + g.V().hasLabel('lp').has('index', {pindex}).as('pp').V().hasLabel('lp').has('index', {index}).as('p'). + addE('goesTo').from('pp').to('p').property('distance', 5).next(); + '''.format( + index=index, pindex=index - 1)) + + queries.append("\n".join(vertex_script)) + return queries + + @staticmethod + def classic(): + queries = [ + ''' + schema.vertexLabel('person').ifNotExists().partitionBy('name', Text).property('age', Int).create(); + schema.vertexLabel('software')ifNotExists().partitionBy('name', Text).property('lang', Text).create(); + schema.edgeLabel('created').ifNotExists().from('person').to('software').property('weight', Double).create(); + schema.edgeLabel('knows').ifNotExists().from('person').to('person').property('weight', Double).create(); + ''', + + ''' + Vertex marko = g.addV('person').property('name', 'marko').property('age', 29).next(); + Vertex vadas = g.addV('person').property('name', 'vadas').property('age', 27).next(); + Vertex lop = g.addV('software').property('name', 'lop').property('lang', 'java').next(); + Vertex josh = g.addV('person').property('name', 'josh').property('age', 32).next(); + Vertex peter = g.addV('person').property('name', 'peter').property('age', 35).next(); + Vertex carl = g.addV('person').property('name', 'carl').property('age', 35).next(); + Vertex ripple = g.addV('software').property('name', 'ripple').property('lang', 'java').next(); + + // TODO, switch to VertexReference and use v.id() + g.V().hasLabel('person').has('name', 'vadas').as('v').V().hasLabel('person').has('name', 'marko').as('m').addE('knows').from('m').to('v').property('weight', 0.5d).next(); + g.V().hasLabel('person').has('name', 'josh').as('j').V().hasLabel('person').has('name', 'marko').as('m').addE('knows').from('m').to('j').property('weight', 1.0d).next(); + g.V().hasLabel('software').has('name', 'lop').as('l').V().hasLabel('person').has('name', 'marko').as('m').addE('created').from('m').to('l').property('weight', 0.4d).next(); + g.V().hasLabel('software').has('name', 'ripple').as('r').V().hasLabel('person').has('name', 'josh').as('j').addE('created').from('j').to('r').property('weight', 1.0d).next(); + g.V().hasLabel('software').has('name', 'lop').as('l').V().hasLabel('person').has('name', 'josh').as('j').addE('created').from('j').to('l').property('weight', 0.4d).next(); + g.V().hasLabel('software').has('name', 'lop').as('l').V().hasLabel('person').has('name', 'peter').as('p').addE('created').from('p').to('l').property('weight', 0.2d).next(); + + '''] + + return queries + + @staticmethod + def multiple_fields(): + ## no generic test currently needs this + raise NotImplementedError() + + @staticmethod + def large(): + query_parts = [ + ''' + schema.vertexLabel('lcg').ifNotExists().partitionBy('ts', Int).property('sin', Double). 
+ property('cos', Double).property('ii', Int).create(); + schema.edgeLabel('linked').ifNotExists().from('lcg').to('lcg').create(); + ''', + + ''' + int size = 2000; + List ids = new ArrayList(); + v = g.addV('lcg').property('ts', 100001).property('sin', 0d).property('cos', 1d).property('ii', 0).next(); + ids.add(v.id()); + Random rand = new Random(); + for (int ii = 1; ii < size; ii++) { + v = g.addV('lcg').property('ts', 100001 + ii).property('sin', Math.sin(ii/5.0)).property('cos', Math.cos(ii/5.0)).property('ii', ii).next(); + + uid = ids.get(rand.nextInt(ids.size())) + g.V(v.id()).as('v').V(uid).as('u').addE('linked').from('v').to('u').next(); + ids.add(v.id()); + } + g.V().count();''' + ] + + return query_parts + + @staticmethod + def address_book(): + queries = [ + "schema.vertexLabel('person').ifNotExists().partitionBy('name', Text)." + "property('pointPropWithBoundsWithSearchIndex', Point)." + "property('pointPropWithBounds', Point)." + "property('pointPropWithGeoBoundsWithSearchIndex', Point)." + "property('pointPropWithGeoBounds', Point)." + "property('city', Text)." + "property('state', Text)." + "property('description', Text).create()", + "schema.vertexLabel('person').searchIndex().by('name').by('pointPropWithBounds').by('pointPropWithGeoBounds').by('description').asText().create()", + "g.addV('person').property('name', 'Paul Thomas Joe').property('city', 'Rochester').property('state', 'MN').property('pointPropWithBoundsWithSearchIndex', Geo.point(-92.46295, 44.0234)).property('pointPropWithBounds', Geo.point(-92.46295, 44.0234)).property('pointPropWithGeoBoundsWithSearchIndex', Geo.point(-92.46295, 44.0234)).property('pointPropWithGeoBounds', Geo.point(-92.46295, 44.0234)).property('description', 'Lives by the hospital').next()", + "g.addV('person').property('name', 'George Bill Steve').property('city', 'Minneapolis').property('state', 'MN').property('pointPropWithBoundsWithSearchIndex', Geo.point(-93.266667, 44.093333)).property('pointPropWithBounds', Geo.point(-93.266667, 44.093333)).property('pointPropWithGeoBoundsWithSearchIndex', Geo.point(-93.266667, 44.093333)).property('pointPropWithGeoBounds', Geo.point(-93.266667, 44.093333)).property('description', 'A cold dude').next()", + "g.addV('person').property('name', 'James Paul Smith').property('city', 'Chicago').property('state', 'IL').property('pointPropWithBoundsWithSearchIndex', Geo.point(-87.684722, 41.836944)).property('description', 'Likes to hang out').next()", + "g.addV('person').property('name', 'Jill Alice').property('city', 'Atlanta').property('state', 'GA').property('pointPropWithBoundsWithSearchIndex', Geo.point(-84.39, 33.755)).property('description', 'Enjoys a nice cold coca cola').next()" + ] + + if not Version('5.0') <= DSE_VERSION < Version('5.1'): + queries.append("schema.vertexLabel('person').searchIndex().by('pointPropWithBoundsWithSearchIndex').by('pointPropWithGeoBounds')" + ".by('pointPropWithGeoBoundsWithSearchIndex').create()") + + return queries + + +def validate_classic_vertex(test, vertex): + vertex_props = vertex.properties.keys() + test.assertEqual(len(vertex_props), 2) + test.assertIn('name', vertex_props) + test.assertTrue('lang' in vertex_props or 'age' in vertex_props) + + +def validate_classic_vertex_return_type(test, vertex): + validate_generic_vertex_result_type(vertex) + vertex_props = vertex.properties + test.assertIn('name', vertex_props) + test.assertTrue('lang' in vertex_props or 'age' in vertex_props) + + +def validate_generic_vertex_result_type(test, vertex): + 
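+    # Shared shape check: every Vertex result should expose id, type, label and
+    # properties.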
test.assertIsInstance(vertex, Vertex) + for attr in ('id', 'type', 'label', 'properties'): + test.assertIsNotNone(getattr(vertex, attr)) + + +def validate_classic_edge_properties(test, edge_properties): + test.assertEqual(len(edge_properties.keys()), 1) + test.assertIn('weight', edge_properties) + test.assertIsInstance(edge_properties, dict) + + +def validate_classic_edge(test, edge): + validate_generic_edge_result_type(test, edge) + validate_classic_edge_properties(test, edge.properties) + + +def validate_line_edge(test, edge): + validate_generic_edge_result_type(test, edge) + edge_props = edge.properties + test.assertEqual(len(edge_props.keys()), 1) + test.assertIn('distance', edge_props) + + +def validate_generic_edge_result_type(test, edge): + test.assertIsInstance(edge, Edge) + for attr in ('properties', 'outV', 'outVLabel', 'inV', 'inVLabel', 'label', 'type', 'id'): + test.assertIsNotNone(getattr(edge, attr)) + + +def validate_path_result_type(test, path): + test.assertIsInstance(path, Path) + test.assertIsNotNone(path.labels) + for obj in path.objects: + if isinstance(obj, Edge): + validate_classic_edge(test, obj) + elif isinstance(obj, Vertex): + validate_classic_vertex(test, obj) + else: + test.fail("Invalid object found in path " + str(object.type)) + + +class GraphTestConfiguration(object): + """Possible Configurations: + ClassicGraphSchema: + graphson1 + graphson2 + graphson3 + + CoreGraphSchema + graphson3 + """ + + @classmethod + def schemas(cls): + schemas = [ClassicGraphSchema] + if DSE_VERSION >= Version("6.8"): + schemas.append(CoreGraphSchema) + return schemas + + @classmethod + def graphson_versions(cls): + graphson_versions = [GraphProtocol.GRAPHSON_1_0] + if DSE_VERSION >= Version("6.0"): + graphson_versions.append(GraphProtocol.GRAPHSON_2_0) + if DSE_VERSION >= Version("6.8"): + graphson_versions.append(GraphProtocol.GRAPHSON_3_0) + return graphson_versions + + @classmethod + def schema_configurations(cls, schema=None): + schemas = cls.schemas() if schema is None else [schema] + configurations = [] + for s in schemas: + configurations.append(s) + + return configurations + + @classmethod + def configurations(cls, schema=None, graphson=None): + schemas = cls.schemas() if schema is None else [schema] + graphson_versions = cls.graphson_versions() if graphson is None else [graphson] + + configurations = [] + for s in schemas: + for g in graphson_versions: + if s is CoreGraphSchema and g != GraphProtocol.GRAPHSON_3_0: + continue + configurations.append((s, g)) + + return configurations + + @staticmethod + def _make_graph_schema_test_method(func, schema): + def test_input(self): + self.setup_graph(schema) + try: + func(self, schema) + except: + raise + finally: + self.teardown_graph(schema) + + schema_name = 'classic' if schema is ClassicGraphSchema else 'core' + test_input.__name__ = '{func}_{schema}'.format( + func=func.__name__.lstrip('_'), schema=schema_name) + return test_input + + @staticmethod + def _make_graph_test_method(func, schema, graphson): + def test_input(self): + self.setup_graph(schema) + try: + func(self, schema, graphson) + except: + raise + finally: + self.teardown_graph(schema) + + graphson_name = 'graphson1' + if graphson == GraphProtocol.GRAPHSON_2_0: + graphson_name = 'graphson2' + elif graphson == GraphProtocol.GRAPHSON_3_0: + graphson_name = 'graphson3' + + schema_name = 'classic' if schema is ClassicGraphSchema else 'core' + + # avoid keyspace name too long issue + if DSE_VERSION < Version('6.7'): + schema_name = schema_name[0] + graphson_name = 
'g' + graphson_name[-1] + + test_input.__name__ = '{func}_{schema}_{graphson}'.format( + func=func.__name__.lstrip('_'), schema=schema_name, graphson=graphson_name) + return test_input + + @classmethod + def generate_tests(cls, schema=None, graphson=None, traversal=False): + """Generate tests for a graph configuration""" + def decorator(klass): + if DSE_VERSION: + predicate = inspect.isfunction + for name, func in inspect.getmembers(klass, predicate=predicate): + if not name.startswith('_test'): + continue + for _schema, _graphson in cls.configurations(schema, graphson): + if traversal and _graphson == GraphProtocol.GRAPHSON_1_0: + continue + test_input = cls._make_graph_test_method(func, _schema, _graphson) + log.debug("Generated test '{}.{}'".format(klass.__name__, test_input.__name__)) + setattr(klass, test_input.__name__, test_input) + return klass + + return decorator + + @classmethod + def generate_schema_tests(cls, schema=None): + """Generate schema tests for a graph configuration""" + def decorator(klass): + if DSE_VERSION: + predicate = inspect.isfunction + for name, func in inspect.getmembers(klass, predicate=predicate): + if not name.startswith('_test'): + continue + for _schema in cls.schema_configurations(schema): + test_input = cls._make_graph_schema_test_method(func, _schema) + log.debug("Generated test '{}.{}'".format(klass.__name__, test_input.__name__)) + setattr(klass, test_input.__name__, test_input) + return klass + + return decorator + + +class VertexLabel(object): + """ + Helper that represents a new VertexLabel: + + VertexLabel(['Int()', 'Float()']) # a vertex with 2 properties named property1 and property2 + VertexLabel([('int1', 'Int()'), 'Float()']) # a vertex with 2 properties named int1 and property1 + """ + + id = 0 + label = None + properties = None + + def __init__(self, properties): + VertexLabel.id += 1 + self.id = VertexLabel.id + self.label = "vertex{}".format(self.id) + self.properties = {'pkid': self.id} + property_count = 0 + for p in properties: + if isinstance(p, tuple): + name, typ = p + else: + property_count += 1 + name = "property-v{}-{}".format(self.id, property_count) + typ = p + self.properties[name] = typ + + @property + def non_pk_properties(self): + return {p: v for p, v in self.properties.items() if p != 'pkid'} + + +class GraphSchema(object): + + has_geo_bounds = DSE_VERSION and DSE_VERSION >= Version('5.1') + fixtures = GraphFixtures + + @classmethod + def sanitize_type(cls, typ): + if typ.lower().startswith("point"): + return cls.sanitize_point_type() + elif typ.lower().startswith("line"): + return cls.sanitize_line_type() + elif typ.lower().startswith("poly"): + return cls.sanitize_polygon_type() + else: + return typ + + @classmethod + def sanitize_point_type(cls): + return "Point().withGeoBounds()" if cls.has_geo_bounds else "Point()" + + @classmethod + def sanitize_line_type(cls): + return "Linestring().withGeoBounds()" if cls.has_geo_bounds else "Linestring()" + + @classmethod + def sanitize_polygon_type(cls): + return "Polygon().withGeoBounds()" if cls.has_geo_bounds else "Polygon()" + + @staticmethod + def drop_graph(session, graph_name): + ks = list(session.execute( + "SELECT * FROM system_schema.keyspaces WHERE keyspace_name = '{}';".format(graph_name))) + if not ks: + return + + try: + session.execute_graph('system.graph(name).drop()', {'name': graph_name}, + execution_profile=EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT) + except: + pass + + @staticmethod + def create_graph(session, graph_name): + raise NotImplementedError() + + 
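+    # clear() is an optional hook: it is a no-op here, and ClassicGraphSchema
+    # overrides it to run schema.clear() against the session.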
@staticmethod + def clear(session): + pass + + @staticmethod + def create_vertex_label(session, vertex_label, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT): + raise NotImplementedError() + + @staticmethod + def add_vertex(session, vertex_label, name, value, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT): + raise NotImplementedError() + + @classmethod + def ensure_properties(cls, session, obj, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT): + if not isinstance(obj, (Vertex, Edge)): + return + + # This pre-processing is due to a change in TinkerPop + # properties are not returned automatically anymore + # with some queries. + if not obj.properties: + if isinstance(obj, Edge): + obj.properties = {} + for p in cls.get_edge_properties(session, obj, execution_profile=execution_profile): + obj.properties.update(p) + elif isinstance(obj, Vertex): + obj.properties = { + p.label: p + for p in cls.get_vertex_properties(session, obj, execution_profile=execution_profile) + } + + @staticmethod + def get_vertex_properties(session, vertex, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT): + return session.execute_graph("g.V(vertex_id).properties().toList()", {'vertex_id': vertex.id}, + execution_profile=execution_profile) + + @staticmethod + def get_edge_properties(session, edge, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT): + v = session.execute_graph("g.E(edge_id).properties().toList()", {'edge_id': edge.id}, + execution_profile=execution_profile) + return v + + +class ClassicGraphSchema(GraphSchema): + + fixtures = ClassicGraphFixtures + + @staticmethod + def create_graph(session, graph_name): + session.execute_graph(CREATE_CLASSIC_GRAPH, {'name': graph_name}, + execution_profile=EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT) + wait_for_graph_inserted(session, graph_name) + + @staticmethod + def clear(session): + session.execute_graph('schema.clear()') + + @classmethod + def create_vertex_label(cls, session, vertex_label, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT): + statements = ["schema.propertyKey('pkid').Int().ifNotExists().create();"] + for k, v in vertex_label.non_pk_properties.items(): + typ = cls.sanitize_type(v) + statements.append("schema.propertyKey('{name}').{type}.create();".format( + name=k, type=typ + )) + + statements.append("schema.vertexLabel('{label}').partitionKey('pkid').properties(".format( + label=vertex_label.label)) + property_names = [name for name in vertex_label.non_pk_properties.keys()] + statements.append(", ".join(["'{}'".format(p) for p in property_names])) + statements.append(").create();") + + to_run = "\n".join(statements) + session.execute_graph(to_run, execution_profile=execution_profile) + + @staticmethod + def add_vertex(session, vertex_label, name, value, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT): + statement = "g.addV('{label}').property('pkid', {pkid}).property('{property_name}', val);".format( + pkid=vertex_label.id, label=vertex_label.label, property_name=name) + parameters = {'val': value} + return session.execute_graph(statement, parameters, execution_profile=execution_profile) + + +class CoreGraphSchema(GraphSchema): + + fixtures = CoreGraphFixtures + + @classmethod + def sanitize_type(cls, typ): + typ = super(CoreGraphSchema, cls).sanitize_type(typ) + return typ.replace('()', '') + + @classmethod + def sanitize_point_type(cls): + return "Point" + + @classmethod + def sanitize_line_type(cls): + return "LineString" + + @classmethod + def sanitize_polygon_type(cls): + return "Polygon" + + @staticmethod + def create_graph(session, graph_name): + 
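+        # Core graphs use a plain system.graph(name).create(); the Classic engine
+        # variant (CREATE_CLASSIC_GRAPH) only applies to ClassicGraphSchema.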
session.execute_graph('system.graph(name).create()', {'name': graph_name}, + execution_profile=EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT) + wait_for_graph_inserted(session, graph_name) + + @classmethod + def create_vertex_label(cls, session, vertex_label, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT): + statements = ["schema.vertexLabel('{label}').partitionBy('pkid', Int)".format( + label=vertex_label.label)] + + for name, typ in vertex_label.non_pk_properties.items(): + typ = cls.sanitize_type(typ) + statements.append(".property('{name}', {type})".format(name=name, type=typ)) + statements.append(".create();") + + to_run = "\n".join(statements) + session.execute_graph(to_run, execution_profile=execution_profile) + + @staticmethod + def add_vertex(session, vertex_label, name, value, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT): + statement = "g.addV('{label}').property('pkid', {pkid}).property('{property_name}', val);".format( + pkid=vertex_label.id, label=vertex_label.label, property_name=name) + parameters = {'val': value} + return session.execute_graph(statement, parameters, execution_profile=execution_profile) diff --git a/tests/integration/advanced/graph/fluent/__init__.py b/tests/integration/advanced/graph/fluent/__init__.py new file mode 100644 index 0000000000..1c07cd46c0 --- /dev/null +++ b/tests/integration/advanced/graph/fluent/__init__.py @@ -0,0 +1,720 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
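+
+# Shared helpers for the fluent (gremlin-python GLV) graph tests.
+#
+# Illustrative sketch of the pattern exercised by these tests (the traversal is
+# an example, not a fixture defined in this file):
+#
+#     ep = DseGraph().create_execution_profile(graph_name,
+#                                              graph_protocol=GraphProtocol.GRAPHSON_3_0)
+#     cluster.add_execution_profile('traversal_graphson3', ep)
+#     g = DseGraph.traversal_source(session, graph_name, execution_profile=ep)
+#     names = g.V().has('name', 'marko').out('knows').values('name').toList()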
+ +import sys +import datetime +import time +from collections import namedtuple +from packaging.version import Version + +from cassandra.datastax.graph.fluent import DseGraph +from cassandra.graph import VertexProperty, GraphProtocol +from cassandra.util import Point, Polygon, LineString + +from gremlin_python.process.graph_traversal import GraphTraversal, GraphTraversalSource +from gremlin_python.process.traversal import P +from gremlin_python.structure.graph import Edge as TravEdge +from gremlin_python.structure.graph import Vertex as TravVertex, VertexProperty as TravVertexProperty + +from tests.util import wait_until_not_raised +from tests.integration import DSE_VERSION +from tests.integration.advanced.graph import ( + GraphUnitTestCase, ClassicGraphSchema, CoreGraphSchema, + VertexLabel) +from tests.integration import requiredse + +import unittest + + +import ipaddress + + +def check_equality_base(testcase, original, read_value): + if isinstance(original, float): + testcase.assertAlmostEqual(original, read_value, delta=.01) + elif isinstance(original, ipaddress.IPv4Address): + testcase.assertAlmostEqual(original, ipaddress.IPv4Address(read_value)) + elif isinstance(original, ipaddress.IPv6Address): + testcase.assertAlmostEqual(original, ipaddress.IPv6Address(read_value)) + else: + testcase.assertEqual(original, read_value) + + +def create_traversal_profiles(cluster, graph_name): + ep_graphson2 = DseGraph().create_execution_profile( + graph_name, graph_protocol=GraphProtocol.GRAPHSON_2_0) + ep_graphson3 = DseGraph().create_execution_profile( + graph_name, graph_protocol=GraphProtocol.GRAPHSON_3_0) + + cluster.add_execution_profile('traversal_graphson2', ep_graphson2) + cluster.add_execution_profile('traversal_graphson3', ep_graphson3) + + return ep_graphson2, ep_graphson3 + + +class _AbstractTraversalTest(GraphUnitTestCase): + + def setUp(self): + super(_AbstractTraversalTest, self).setUp() + self.ep_graphson2, self.ep_graphson3 = create_traversal_profiles(self.cluster, self.graph_name) + + def _test_basic_query(self, schema, graphson): + """ + Test to validate that basic graph queries works + + Creates a simple classic tinkerpot graph, and attempts to preform a basic query + using Tinkerpop's GLV with both explicit and implicit execution + ensuring that each one is correct. See reference graph here + http://www.tinkerpop.com/docs/3.0.0.M1/ + + @since 1.0.0 + @jira_ticket PYTHON-641 + @expected_result graph should generate and all vertices and edge results should be + + @test_category dse graph + """ + + g = self.fetch_traversal_source(graphson) + self.execute_graph(schema.fixtures.classic(), graphson) + traversal = g.V().has('name', 'marko').out('knows').values('name') + results_list = self.execute_traversal(traversal, graphson) + self.assertEqual(len(results_list), 2) + self.assertIn('vadas', results_list) + self.assertIn('josh', results_list) + + def _test_classic_graph(self, schema, graphson): + """ + Test to validate that basic graph generation, and vertex and edges are surfaced correctly + + Creates a simple classic tinkerpot graph, and iterates over the the vertices and edges + using Tinkerpop's GLV with both explicit and implicit execution + ensuring that each one iscorrect. 
See reference graph here + http://www.tinkerpop.com/docs/3.0.0.M1/ + + @since 1.0.0 + @jira_ticket PYTHON-641 + @expected_result graph should generate and all vertices and edge results should be + + @test_category dse graph + """ + + self.execute_graph(schema.fixtures.classic(), graphson) + ep = self.get_execution_profile(graphson) + g = self.fetch_traversal_source(graphson) + traversal = g.V() + vert_list = self.execute_traversal(traversal, graphson) + + for vertex in vert_list: + schema.ensure_properties(self.session, vertex, execution_profile=ep) + self._validate_classic_vertex(g, vertex) + traversal = g.E() + edge_list = self.execute_traversal(traversal, graphson) + for edge in edge_list: + schema.ensure_properties(self.session, edge, execution_profile=ep) + self._validate_classic_edge(g, edge) + + def _test_graph_classic_path(self, schema, graphson): + """ + Test to validate that the path version of the result type is generated correctly. It also + tests basic path results as that is not covered elsewhere + + @since 1.0.0 + @jira_ticket PYTHON-641 + @expected_result path object should be unpacked correctly including all nested edges and vertices + @test_category dse graph + """ + self.execute_graph(schema.fixtures.classic(), graphson) + g = self.fetch_traversal_source(graphson) + traversal = g.V().hasLabel('person').has('name', 'marko').as_('a').outE('knows').inV().as_('c', 'd').outE('created').as_('e', 'f', 'g').inV().path() + path_list = self.execute_traversal(traversal, graphson) + self.assertEqual(len(path_list), 2) + for path in path_list: + self._validate_path_result_type(g, path) + + def _test_range_query(self, schema, graphson): + """ + Test to validate range queries are handled correctly. + + Creates a very large line graph script and executes it. Then proceeds to to a range + limited query against it, and ensure that the results are formated correctly and that + the result set is properly sized. + + @since 1.0.0 + @jira_ticket PYTHON-641 + @expected_result result set should be properly formated and properly sized + + @test_category dse graph + """ + + self.execute_graph(schema.fixtures.line(150), graphson) + ep = self.get_execution_profile(graphson) + g = self.fetch_traversal_source(graphson) + + traversal = g.E().range(0, 10) + edges = self.execute_traversal(traversal, graphson) + self.assertEqual(len(edges), 10) + for edge in edges: + schema.ensure_properties(self.session, edge, execution_profile=ep) + self._validate_line_edge(g, edge) + + def _test_result_types(self, schema, graphson): + """ + Test to validate that the edge and vertex version of results are constructed correctly. + + @since 1.0.0 + @jira_ticket PYTHON-641 + @expected_result edge/vertex result types should be unpacked correctly. + @test_category dse graph + """ + self.execute_graph(schema.fixtures.line(150), graphson) + g = self.fetch_traversal_source(graphson) + traversal = g.V() + vertices = self.execute_traversal(traversal, graphson) + for vertex in vertices: + self._validate_type(g, vertex) + + def _test_large_result_set(self, schema, graphson): + """ + Test to validate that large result sets return correctly. + + Creates a very large graph. Ensures that large result sets are handled appropriately. 
+ + @since 1.0.0 + @jira_ticket PYTHON-641 + @expected_result when limits of result sets are hit errors should be surfaced appropriately + + @test_category dse graph + """ + self.execute_graph(schema.fixtures.large(), graphson) + g = self.fetch_traversal_source(graphson) + traversal = g.V() + vertices = self.execute_traversal(traversal, graphson) + for vertex in vertices: + self._validate_generic_vertex_result_type(g, vertex) + + def _test_vertex_meta_properties(self, schema, graphson): + """ + Test verifying vertex property properties + + @since 1.0.0 + @jira_ticket PYTHON-641 + + @test_category dse graph + """ + if schema is not ClassicGraphSchema: + raise unittest.SkipTest('skipped because multiple properties are only supported with classic graphs') + + s = self.session + s.execute_graph("schema.propertyKey('k0').Text().ifNotExists().create();") + s.execute_graph("schema.propertyKey('k1').Text().ifNotExists().create();") + s.execute_graph("schema.propertyKey('key').Text().properties('k0', 'k1').ifNotExists().create();") + s.execute_graph("schema.vertexLabel('MLP').properties('key').ifNotExists().create();") + s.execute_graph("schema.config().option('graph.allow_scan').set('true');") + v = s.execute_graph('''v = graph.addVertex('MLP') + v.property('key', 'meta_prop', 'k0', 'v0', 'k1', 'v1') + v''')[0] + + g = self.fetch_traversal_source(graphson) + + traversal = g.V() + # This should contain key, and value where value is a property + # This should be a vertex property and should contain sub properties + results = self.execute_traversal(traversal, graphson) + self._validate_meta_property(g, results[0]) + + def _test_vertex_multiple_properties(self, schema, graphson): + """ + Test verifying vertex property form for various Cardinality + + All key types are encoded as a list, regardless of cardinality + + Single cardinality properties have only one value -- the last one added + + Default is single (this is config dependent) + + @since 1.0.0 + @jira_ticket PYTHON-641 + + @test_category dse graph + """ + if schema is not ClassicGraphSchema: + raise unittest.SkipTest('skipped because multiple properties are only supported with classic graphs') + + s = self.session + s.execute_graph('''Schema schema = graph.schema(); + schema.propertyKey('mult_key').Text().multiple().ifNotExists().create(); + schema.propertyKey('single_key').Text().single().ifNotExists().create(); + schema.vertexLabel('MPW1').properties('mult_key').ifNotExists().create(); + schema.vertexLabel('MPW2').properties('mult_key').ifNotExists().create(); + schema.vertexLabel('SW1').properties('single_key').ifNotExists().create();''') + + mpw1v = s.execute_graph('''v = graph.addVertex('MPW1') + v.property('mult_key', 'value') + v''')[0] + + mpw2v = s.execute_graph('''g.addV('MPW2').property('mult_key', 'value0').property('mult_key', 'value1')''')[0] + + g = self.fetch_traversal_source(graphson) + traversal = g.V(mpw1v.id).properties() + + vertex_props = self.execute_traversal(traversal, graphson) + + self.assertEqual(len(vertex_props), 1) + + self.assertEqual(self.fetch_key_from_prop(vertex_props[0]), "mult_key") + self.assertEqual(vertex_props[0].value, "value") + + # multiple_with_two_values + #v = s.execute_graph('''g.addV(label, 'MPW2', 'mult_key', 'value0', 'mult_key', 'value1')''')[0] + traversal = g.V(mpw2v.id).properties() + + vertex_props = self.execute_traversal(traversal, graphson) + + self.assertEqual(len(vertex_props), 2) + self.assertEqual(self.fetch_key_from_prop(vertex_props[0]), 'mult_key') + 
self.assertEqual(self.fetch_key_from_prop(vertex_props[1]), 'mult_key') + self.assertEqual(vertex_props[0].value, 'value0') + self.assertEqual(vertex_props[1].value, 'value1') + + # single_with_one_value + v = s.execute_graph('''v = graph.addVertex('SW1') + v.property('single_key', 'value') + v''')[0] + traversal = g.V(v.id).properties() + vertex_props = self.execute_traversal(traversal, graphson) + self.assertEqual(len(vertex_props), 1) + self.assertEqual(self.fetch_key_from_prop(vertex_props[0]), "single_key") + self.assertEqual(vertex_props[0].value, "value") + + def should_parse_meta_properties(self): + g = self.fetch_traversal_source() + g.addV("meta_v").property("meta_prop", "hello", "sub_prop", "hi", "sub_prop2", "hi2") + + def _test_all_graph_types_with_schema(self, schema, graphson): + """ + Exhaustively goes through each type that is supported by dse_graph. + creates a vertex for each type using a dse-tinkerpop traversal, + It then attempts to fetch it from the server and compares it to what was inserted + Prime the graph with the correct schema first + + @since 1.0.0 + @jira_ticket PYTHON-641 + @expected_result inserted objects are equivalent to those retrieved + + @test_category dse graph + """ + self._write_and_read_data_types(schema, graphson) + + def _test_all_graph_types_without_schema(self, schema, graphson): + """ + Exhaustively goes through each type that is supported by dse_graph. + creates a vertex for each type using a dse-tinkerpop traversal, + It then attempts to fetch it from the server and compares it to what was inserted + Do not prime the graph with the correct schema first + @since 1.0.0 + @jira_ticket PYTHON-641 + @expected_result inserted objects are equivalent to those retrieved + @test_category dse graph + """ + if schema is not ClassicGraphSchema: + raise unittest.SkipTest('schema-less is only for classic graphs') + self._write_and_read_data_types(schema, graphson, use_schema=False) + + def _test_dsl(self, schema, graphson): + """ + The test creates a SocialTraversal and a SocialTraversalSource as part of + a DSL. Then calls it's method and checks the results to verify + we have the expected results + + @since @since 1.1.0a1 + @jira_ticket PYTHON-790 + @expected_result only the vertex corresponding to marko is in the result + + @test_category dse graph + """ + class SocialTraversal(GraphTraversal): + def knows(self, person_name): + return self.out("knows").hasLabel("person").has("name", person_name).in_() + + class SocialTraversalSource(GraphTraversalSource): + def __init__(self, *args, **kwargs): + super(SocialTraversalSource, self).__init__(*args, **kwargs) + self.graph_traversal = SocialTraversal + + def people(self, *names): + return self.get_graph_traversal().V().has("name", P.within(*names)) + + self.execute_graph(schema.fixtures.classic(), graphson) + if schema is CoreGraphSchema: + self.execute_graph(""" + schema.edgeLabel('knows').from('person').to('person').materializedView('person__knows__person_by_in_name'). 
+ ifNotExists().partitionBy('in_name').clusterBy('out_name', Asc).create() + """, graphson) + time.sleep(1) # give some time to the MV to be populated + g = self.fetch_traversal_source(graphson, traversal_class=SocialTraversalSource) + + traversal = g.people("marko", "albert").knows("vadas") + results = self.execute_traversal(traversal, graphson) + self.assertEqual(len(results), 1) + only_vertex = results[0] + schema.ensure_properties(self.session, only_vertex, + execution_profile=self.get_execution_profile(graphson)) + self._validate_classic_vertex(g, only_vertex) + + def _test_bulked_results(self, schema, graphson): + """ + Send a query expecting a bulked result and the driver "undoes" + the bulk and returns the expected list + + @since 1.1.0a1 + @jira_ticket PYTHON-771 + @expected_result the expanded list + + @test_category dse graph + """ + self.execute_graph(schema.fixtures.classic(), graphson) + g = self.fetch_traversal_source(graphson) + barrier_traversal = g.E().label().barrier() + results = self.execute_traversal(barrier_traversal, graphson) + self.assertEqual(sorted(["created", "created", "created", "created", "knows", "knows"]), sorted(results)) + + def _test_udt_with_classes(self, schema, graphson): + class Address(object): + + def __init__(self, address, city, state): + self.address = address + self.city = city + self.state = state + + def __eq__(self, other): + return self.address == other.address and self.city == other.city and self.state == other.state + + class AddressWithTags(object): + + def __init__(self, address, city, state, tags): + self.address = address + self.city = city + self.state = state + self.tags = tags + + def __eq__(self, other): + return (self.address == other.address and self.city == other.city + and self.state == other.state and self.tags == other.tags) + + class ComplexAddress(object): + + def __init__(self, address, address_tags, city, state, props): + self.address = address + self.address_tags = address_tags + self.city = city + self.state = state + self.props = props + + def __eq__(self, other): + return (self.address == other.address and self.address_tags == other.address_tags + and self.city == other.city and self.state == other.state + and self.props == other.props) + + class ComplexAddressWithOwners(object): + + def __init__(self, address, address_tags, city, state, props, owners): + self.address = address + self.address_tags = address_tags + self.city = city + self.state = state + self.props = props + self.owners = owners + + def __eq__(self, other): + return (self.address == other.address and self.address_tags == other.address_tags + and self.city == other.city and self.state == other.state + and self.props == other.props and self.owners == other.owners) + + self.__test_udt(schema, graphson, Address, AddressWithTags, ComplexAddress, ComplexAddressWithOwners) + + def _test_udt_with_namedtuples(self, schema, graphson): + AddressTuple = namedtuple('Address', ('address', 'city', 'state')) + AddressWithTagsTuple = namedtuple('AddressWithTags', ('address', 'city', 'state', 'tags')) + ComplexAddressTuple = namedtuple('ComplexAddress', ('address', 'address_tags', 'city', 'state', 'props')) + ComplexAddressWithOwnersTuple = namedtuple('ComplexAddressWithOwners', ('address', 'address_tags', 'city', + 'state', 'props', 'owners')) + + self.__test_udt(schema, graphson, AddressTuple, AddressWithTagsTuple, + ComplexAddressTuple, ComplexAddressWithOwnersTuple) + + def _write_and_read_data_types(self, schema, graphson, use_schema=True): + g = 
self.fetch_traversal_source(graphson) + ep = self.get_execution_profile(graphson) + for data in schema.fixtures.datatypes().values(): + typ, value, deserializer = data + vertex_label = VertexLabel([typ]) + property_name = next(iter(vertex_label.non_pk_properties.keys())) + if use_schema or schema is CoreGraphSchema: + schema.create_vertex_label(self.session, vertex_label, execution_profile=ep) + + write_traversal = g.addV(str(vertex_label.label)).property('pkid', vertex_label.id).\ + property(property_name, value) + self.execute_traversal(write_traversal, graphson) + + read_traversal = g.V().hasLabel(str(vertex_label.label)).has(property_name).properties() + results = self.execute_traversal(read_traversal, graphson) + + for result in results: + if result.label == 'pkid': + continue + self._check_equality(g, value, result.value) + + def __test_udt(self, schema, graphson, address_class, address_with_tags_class, + complex_address_class, complex_address_with_owners_class): + if schema is not CoreGraphSchema or DSE_VERSION < Version('6.8'): + raise unittest.SkipTest("Graph UDT is only supported with DSE 6.8+ and Core graphs.") + + ep = self.get_execution_profile(graphson) + + Address = address_class + AddressWithTags = address_with_tags_class + ComplexAddress = complex_address_class + ComplexAddressWithOwners = complex_address_with_owners_class + + # setup udt + self.session.execute_graph(""" + schema.type('address').property('address', Text).property('city', Text).property('state', Text).create(); + schema.type('addressTags').property('address', Text).property('city', Text).property('state', Text). + property('tags', setOf(Text)).create(); + schema.type('complexAddress').property('address', Text).property('address_tags', frozen(typeOf('addressTags'))). + property('city', Text).property('state', Text).property('props', mapOf(Text, Int)).create(); + schema.type('complexAddressWithOwners').property('address', Text). + property('address_tags', frozen(typeOf('addressTags'))). + property('city', Text).property('state', Text).property('props', mapOf(Text, Int)). + property('owners', frozen(listOf(tupleOf(Text, Int)))).create(); + """, execution_profile=ep) + + # wait max 10 seconds to get the UDT discovered. 
+ wait_until_not_raised( + lambda: self.session.cluster.register_user_type(self.graph_name, 'address', Address), + 1, 10) + wait_until_not_raised( + lambda: self.session.cluster.register_user_type(self.graph_name, 'addressTags', AddressWithTags), + 1, 10) + wait_until_not_raised( + lambda: self.session.cluster.register_user_type(self.graph_name, 'complexAddress', ComplexAddress), + 1, 10) + wait_until_not_raised( + lambda: self.session.cluster.register_user_type(self.graph_name, 'complexAddressWithOwners', ComplexAddressWithOwners), + 1, 10) + + data = { + "udt1": ["typeOf('address')", Address('1440 Rd Smith', 'Quebec', 'QC')], + "udt2": ["tupleOf(typeOf('address'), Text)", (Address('1440 Rd Smith', 'Quebec', 'QC'), 'hello')], + "udt3": ["tupleOf(frozen(typeOf('address')), Text)", (Address('1440 Rd Smith', 'Quebec', 'QC'), 'hello')], + "udt4": ["tupleOf(tupleOf(Int, typeOf('address')), Text)", + ((42, Address('1440 Rd Smith', 'Quebec', 'QC')), 'hello')], + "udt5": ["tupleOf(tupleOf(Int, typeOf('addressTags')), Text)", + ((42, AddressWithTags('1440 Rd Smith', 'Quebec', 'QC', {'t1', 't2'})), 'hello')], + "udt6": ["tupleOf(tupleOf(Int, typeOf('complexAddress')), Text)", + ((42, ComplexAddress('1440 Rd Smith', + AddressWithTags('1440 Rd Smith', 'Quebec', 'QC', {'t1', 't2'}), + 'Quebec', 'QC', {'p1': 42, 'p2': 33})), 'hello')], + "udt7": ["tupleOf(tupleOf(Int, frozen(typeOf('complexAddressWithOwners'))), Text)", + ((42, ComplexAddressWithOwners( + '1440 Rd Smith', + AddressWithTags('1440 CRd Smith', 'Quebec', 'QC', {'t1', 't2'}), + 'Quebec', 'QC', {'p1': 42, 'p2': 33}, [('Mike', 43), ('Gina', 39)]) + ), 'hello')] + } + + g = self.fetch_traversal_source(graphson) + for typ, value in data.values(): + vertex_label = VertexLabel([typ]) + property_name = next(iter(vertex_label.non_pk_properties.keys())) + schema.create_vertex_label(self.session, vertex_label, execution_profile=ep) + + write_traversal = g.addV(str(vertex_label.label)).property('pkid', vertex_label.id). 
\ + property(property_name, value) + self.execute_traversal(write_traversal, graphson) + + #vertex = list(schema.add_vertex(self.session, vertex_label, property_name, value, execution_profile=ep))[0] + #vertex_properties = list(schema.get_vertex_properties( + # self.session, vertex, execution_profile=ep)) + + read_traversal = g.V().hasLabel(str(vertex_label.label)).has(property_name).properties() + vertex_properties = self.execute_traversal(read_traversal, graphson) + + self.assertEqual(len(vertex_properties), 2) # include pkid + for vp in vertex_properties: + if vp.label == 'pkid': + continue + + self.assertIsInstance(vp, (VertexProperty, TravVertexProperty)) + self.assertEqual(vp.label, property_name) + self.assertEqual(vp.value, value) + + @staticmethod + def fetch_edge_props(g, edge): + edge_props = g.E(edge.id).properties().toList() + return edge_props + + @staticmethod + def fetch_vertex_props(g, vertex): + + vertex_props = g.V(vertex.id).properties().toList() + return vertex_props + + def _check_equality(self, g, original, read_value): + return check_equality_base(self, original, read_value) + + +def _validate_prop(key, value, unittest): + if key == 'index': + return + + if any(key.startswith(t) for t in ('int', 'short')): + typ = int + + elif any(key.startswith(t) for t in ('long',)): + if sys.version_info >= (3, 0): + typ = int + else: + typ = long + elif any(key.startswith(t) for t in ('float', 'double')): + typ = float + elif any(key.startswith(t) for t in ('polygon',)): + typ = Polygon + elif any(key.startswith(t) for t in ('point',)): + typ = Point + elif any(key.startswith(t) for t in ('Linestring',)): + typ = LineString + elif any(key.startswith(t) for t in ('neg',)): + typ = str + elif any(key.startswith(t) for t in ('date',)): + typ = datetime.date + elif any(key.startswith(t) for t in ('time',)): + typ = datetime.time + else: + unittest.fail("Received unexpected type: %s" % key) + + +@requiredse +class BaseImplicitExecutionTest(GraphUnitTestCase): + """ + This test class will execute all tests of the AbstractTraversalTestClass using implicit execution + This all traversal will be run directly using toList() + """ + def setUp(self): + super(BaseImplicitExecutionTest, self).setUp() + if DSE_VERSION: + self.ep = DseGraph().create_execution_profile(self.graph_name) + self.cluster.add_execution_profile(self.graph_name, self.ep) + + @staticmethod + def fetch_key_from_prop(property): + return property.key + + def fetch_traversal_source(self, graphson, **kwargs): + ep = self.get_execution_profile(graphson, traversal=True) + return DseGraph().traversal_source(self.session, self.graph_name, execution_profile=ep, **kwargs) + + def execute_traversal(self, traversal, graphson=None): + return traversal.toList() + + def _validate_classic_vertex(self, g, vertex): + # Checks the properties on a classic vertex for correctness + vertex_props = self.fetch_vertex_props(g, vertex) + vertex_prop_keys = [vp.key for vp in vertex_props] + self.assertEqual(len(vertex_prop_keys), 2) + self.assertIn('name', vertex_prop_keys) + self.assertTrue('lang' in vertex_prop_keys or 'age' in vertex_prop_keys) + + def _validate_generic_vertex_result_type(self, g, vertex): + # Checks a vertex object for it's generic properties + properties = self.fetch_vertex_props(g, vertex) + for attr in ('id', 'label'): + self.assertIsNotNone(getattr(vertex, attr)) + self.assertTrue(len(properties) > 2) + + def _validate_classic_edge_properties(self, g, edge): + # Checks the properties on a classic edge for correctness + 
edge_props = self.fetch_edge_props(g, edge) + edge_prop_keys = [ep.key for ep in edge_props] + self.assertEqual(len(edge_prop_keys), 1) + self.assertIn('weight', edge_prop_keys) + + def _validate_classic_edge(self, g, edge): + self._validate_generic_edge_result_type(edge) + self._validate_classic_edge_properties(g, edge) + + def _validate_line_edge(self, g, edge): + self._validate_generic_edge_result_type(edge) + edge_props = self.fetch_edge_props(g, edge) + edge_prop_keys = [ep.key for ep in edge_props] + self.assertEqual(len(edge_prop_keys), 1) + self.assertIn('distance', edge_prop_keys) + + def _validate_generic_edge_result_type(self, edge): + self.assertIsInstance(edge, TravEdge) + + for attr in ('outV', 'inV', 'label', 'id'): + self.assertIsNotNone(getattr(edge, attr)) + + def _validate_path_result_type(self, g, objects_path): + for obj in objects_path: + if isinstance(obj, TravEdge): + self._validate_classic_edge(g, obj) + elif isinstance(obj, TravVertex): + self._validate_classic_vertex(g, obj) + else: + self.fail("Invalid object found in path " + str(obj.type)) + + def _validate_meta_property(self, g, vertex): + meta_props = g.V(vertex.id).properties().toList() + self.assertEqual(len(meta_props), 1) + meta_prop = meta_props[0] + self.assertEqual(meta_prop.value, "meta_prop") + self.assertEqual(meta_prop.key, "key") + + nested_props = g.V(vertex.id).properties().properties().toList() + self.assertEqual(len(nested_props), 2) + for nested_prop in nested_props: + self.assertTrue(nested_prop.key in ['k0', 'k1']) + self.assertTrue(nested_prop.value in ['v0', 'v1']) + + def _validate_type(self, g, vertex): + props = self.fetch_vertex_props(g, vertex) + for prop in props: + value = prop.value + key = prop.key + _validate_prop(key, value, self) + + +class BaseExplicitExecutionTest(GraphUnitTestCase): + + def fetch_traversal_source(self, graphson, **kwargs): + ep = self.get_execution_profile(graphson, traversal=True) + return DseGraph().traversal_source(self.session, self.graph_name, execution_profile=ep, **kwargs) + + def execute_traversal(self, traversal, graphson): + ep = self.get_execution_profile(graphson, traversal=True) + ep = self.session.get_execution_profile(ep) + context = None + if graphson == GraphProtocol.GRAPHSON_3_0: + context = { + 'cluster': self.cluster, + 'graph_name': ep.graph_options.graph_name.decode('utf-8') if ep.graph_options.graph_name else None + } + query = DseGraph.query_from_traversal(traversal, graphson, context=context) + # Use an ep that is configured with the correct row factory, and bytecode-json language flat set + result_set = self.execute_graph(query, graphson, traversal=True) + return list(result_set) diff --git a/tests/integration/advanced/graph/fluent/test_graph.py b/tests/integration/advanced/graph/fluent/test_graph.py new file mode 100644 index 0000000000..a2c01affb3 --- /dev/null +++ b/tests/integration/advanced/graph/fluent/test_graph.py @@ -0,0 +1,243 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cassandra import cluster +from cassandra.cluster import ContinuousPagingOptions +from cassandra.datastax.graph.fluent import DseGraph +from cassandra.graph import VertexProperty + +from tests.integration import greaterthanorequaldse68 +from tests.integration.advanced.graph import ( + GraphUnitTestCase, ClassicGraphSchema, CoreGraphSchema, + VertexLabel, GraphTestConfiguration +) +from tests.integration import greaterthanorequaldse60 +from tests.integration.advanced.graph.fluent import ( + BaseExplicitExecutionTest, create_traversal_profiles, check_equality_base) + +import unittest + + +@greaterthanorequaldse60 +@GraphTestConfiguration.generate_tests(traversal=True) +class BatchStatementTests(BaseExplicitExecutionTest): + + def setUp(self): + super(BatchStatementTests, self).setUp() + self.ep_graphson2, self.ep_graphson3 = create_traversal_profiles(self.cluster, self.graph_name) + + def _test_batch_with_schema(self, schema, graphson): + """ + Sends a Batch statement and verifies it has succeeded with a schema created + + @since 1.1.0 + @jira_ticket PYTHON-789 + @expected_result ValueError is arisen + + @test_category dse graph + """ + self._send_batch_and_read_results(schema, graphson) + + def _test_batch_without_schema(self, schema, graphson): + """ + Sends a Batch statement and verifies it has succeeded without a schema created + + @since 1.1.0 + @jira_ticket PYTHON-789 + @expected_result ValueError is arisen + + @test_category dse graph + """ + if schema is not ClassicGraphSchema: + raise unittest.SkipTest('schema-less is only for classic graphs') + self._send_batch_and_read_results(schema, graphson, use_schema=False) + + def _test_batch_with_schema_add_all(self, schema, graphson): + """ + Sends a Batch statement and verifies it has succeeded with a schema created. 
+ Uses :method:`dse_graph.query._BatchGraphStatement.add_all` to add the statements + instead of :method:`dse_graph.query._BatchGraphStatement.add` + + @since 1.1.0 + @jira_ticket PYTHON-789 + @expected_result ValueError is arisen + + @test_category dse graph + """ + self._send_batch_and_read_results(schema, graphson, add_all=True) + + def _test_batch_without_schema_add_all(self, schema, graphson): + """ + Sends a Batch statement and verifies it has succeeded without a schema created + Uses :method:`dse_graph.query._BatchGraphStatement.add_all` to add the statements + instead of :method:`dse_graph.query._BatchGraphStatement.add` + + @since 1.1.0 + @jira_ticket PYTHON-789 + @expected_result ValueError is arisen + + @test_category dse graph + """ + if schema is not ClassicGraphSchema: + raise unittest.SkipTest('schema-less is only for classic graphs') + self._send_batch_and_read_results(schema, graphson, add_all=True, use_schema=False) + + def test_only_graph_traversals_are_accepted(self): + """ + Verifies that ValueError is risen if the parameter add is not a traversal + + @since 1.1.0 + @jira_ticket PYTHON-789 + @expected_result ValueError is arisen + + @test_category dse graph + """ + batch = DseGraph.batch() + self.assertRaises(ValueError, batch.add, '{"@value":{"step":[["addV","poc_int"],' + '["property","bigint1value",{"@value":12,"@type":"g:Int32"}]]},' + '"@type":"g:Bytecode"}') + another_batch = DseGraph.batch() + self.assertRaises(ValueError, batch.add, another_batch) + + def _send_batch_and_read_results(self, schema, graphson, add_all=False, use_schema=True): + traversals = [] + datatypes = schema.fixtures.datatypes() + values = {} + g = self.fetch_traversal_source(graphson) + ep = self.get_execution_profile(graphson) + batch = DseGraph.batch(session=self.session, + execution_profile=self.get_execution_profile(graphson, traversal=True)) + for data in datatypes.values(): + typ, value, deserializer = data + vertex_label = VertexLabel([typ]) + property_name = next(iter(vertex_label.non_pk_properties.keys())) + values[property_name] = value + if use_schema or schema is CoreGraphSchema: + schema.create_vertex_label(self.session, vertex_label, execution_profile=ep) + + traversal = g.addV(str(vertex_label.label)).property('pkid', vertex_label.id).property(property_name, value) + if not add_all: + batch.add(traversal) + traversals.append(traversal) + + if add_all: + batch.add_all(traversals) + + self.assertEqual(len(datatypes), len(batch)) + + batch.execute() + + vertices = self.execute_traversal(g.V(), graphson) + self.assertEqual(len(vertices), len(datatypes), "g.V() returned {}".format(vertices)) + + # Iterate over all the vertices and check that they match the original input + for vertex in vertices: + schema.ensure_properties(self.session, vertex, execution_profile=ep) + key = [k for k in list(vertex.properties.keys()) if k != 'pkid'][0].replace("value", "") + original = values[key] + self._check_equality(original, vertex) + + def _check_equality(self, original, vertex): + for key in vertex.properties: + if key == 'pkid': + continue + value = vertex.properties[key].value \ + if isinstance(vertex.properties[key], VertexProperty) else vertex.properties[key][0].value + check_equality_base(self, original, value) + + +class ContinuousPagingOptionsForTests(ContinuousPagingOptions): + def __init__(self, + page_unit=ContinuousPagingOptions.PagingUnit.ROWS, max_pages=1, # max_pages=1 + max_pages_per_second=0, max_queue_size=4): + super(ContinuousPagingOptionsForTests, 
self).__init__(page_unit, max_pages, max_pages_per_second, + max_queue_size) + + +def reset_paging_options(): + cluster.ContinuousPagingOptions = ContinuousPagingOptions + + +@greaterthanorequaldse68 +@GraphTestConfiguration.generate_tests(schema=CoreGraphSchema) +class GraphPagingTest(GraphUnitTestCase): + + def setUp(self): + super(GraphPagingTest, self).setUp() + self.addCleanup(reset_paging_options) + self.ep_graphson2, self.ep_graphson3 = create_traversal_profiles(self.cluster, self.graph_name) + + def _setup_data(self, schema, graphson): + self.execute_graph( + "schema.vertexLabel('person').ifNotExists().partitionBy('name', Text).property('age', Int).create();", + graphson) + for i in range(100): + self.execute_graph("g.addV('person').property('name', 'batman-{}')".format(i), graphson) + + def _test_cont_paging_is_enabled_by_default(self, schema, graphson): + """ + Test that graph paging is automatically enabled with a >=6.8 cluster. + + @jira_ticket PYTHON-1045 + @expected_result the default continuous paging options are used + + @test_category dse graph + """ + # with traversals... I don't have access to the response future... so this is a hack to ensure paging is on + cluster.ContinuousPagingOptions = ContinuousPagingOptionsForTests + ep = self.get_execution_profile(graphson, traversal=True) + self._setup_data(schema, graphson) + self.session.default_fetch_size = 10 + g = DseGraph.traversal_source(self.session, execution_profile=ep) + results = g.V().toList() + self.assertEqual(len(results), 10) # only 10 results due to our hack + + def _test_cont_paging_can_be_disabled(self, schema, graphson): + """ + Test that graph paging can be disabled. + + @jira_ticket PYTHON-1045 + @expected_result the default continuous paging options are not used + + @test_category dse graph + """ + # with traversals... I don't have access to the response future... so this is a hack to ensure paging is on + cluster.ContinuousPagingOptions = ContinuousPagingOptionsForTests + ep = self.get_execution_profile(graphson, traversal=True) + ep = self.session.execution_profile_clone_update(ep, continuous_paging_options=None) + self._setup_data(schema, graphson) + self.session.default_fetch_size = 10 + g = DseGraph.traversal_source(self.session, execution_profile=ep) + results = g.V().toList() + self.assertEqual(len(results), 100) # 100 results since paging is disabled + + def _test_cont_paging_with_custom_options(self, schema, graphson): + """ + Test that we can specify custom paging options. + + @jira_ticket PYTHON-1045 + @expected_result we get only the desired number of results + + @test_category dse graph + """ + ep = self.get_execution_profile(graphson, traversal=True) + ep = self.session.execution_profile_clone_update(ep, + continuous_paging_options=ContinuousPagingOptions(max_pages=1)) + self._setup_data(schema, graphson) + self.session.default_fetch_size = 10 + g = DseGraph.traversal_source(self.session, execution_profile=ep) + results = g.V().toList() + self.assertEqual(len(results), 10) # only 10 results since paging is disabled diff --git a/tests/integration/advanced/graph/fluent/test_graph_explicit_execution.py b/tests/integration/advanced/graph/fluent/test_graph_explicit_execution.py new file mode 100644 index 0000000000..a5dd4306c5 --- /dev/null +++ b/tests/integration/advanced/graph/fluent/test_graph_explicit_execution.py @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cassandra.graph import Vertex, Edge + +from tests.integration.advanced.graph import ( + validate_classic_vertex, validate_classic_edge, validate_generic_vertex_result_type, + validate_classic_edge_properties, validate_line_edge, + validate_generic_edge_result_type, validate_path_result_type) + +from tests.integration import requiredse, DSE_VERSION +from tests.integration.advanced import use_single_node_with_graph +from tests.integration.advanced.graph import GraphTestConfiguration +from tests.integration.advanced.graph.fluent import ( + BaseExplicitExecutionTest, _AbstractTraversalTest, _validate_prop) + + +def setup_module(): + if DSE_VERSION: + dse_options = {'graph': {'realtime_evaluation_timeout_in_seconds': 60}} + use_single_node_with_graph(dse_options=dse_options) + + +@requiredse +@GraphTestConfiguration.generate_tests(traversal=True) +class ExplicitExecutionTest(BaseExplicitExecutionTest, _AbstractTraversalTest): + """ + This test class will execute all tests of the AbstractTraversalTestClass using Explicit execution + All queries will be run by converting them to byte code, and calling execute graph explicitly with a generated ep. + """ + @staticmethod + def fetch_key_from_prop(property): + return property.label + + def _validate_classic_vertex(self, g, vertex): + validate_classic_vertex(self, vertex) + + def _validate_generic_vertex_result_type(self, g, vertex): + validate_generic_vertex_result_type(self, vertex) + + def _validate_classic_edge_properties(self, g, edge): + validate_classic_edge_properties(self, edge) + + def _validate_classic_edge(self, g, edge): + validate_classic_edge(self, edge) + + def _validate_line_edge(self, g, edge): + validate_line_edge(self, edge) + + def _validate_generic_edge_result_type(self, edge): + validate_generic_edge_result_type(self, edge) + + def _validate_type(self, g, vertex): + for key in vertex.properties: + value = vertex.properties[key][0].value + _validate_prop(key, value, self) + + def _validate_path_result_type(self, g, path_obj): + # This pre-processing is due to a change in TinkerPop + # properties are not returned automatically anymore + # with some queries. 
+ for obj in path_obj.objects: + if not obj.properties: + props = [] + if isinstance(obj, Edge): + obj.properties = { + p.key: p.value + for p in self.fetch_edge_props(g, obj) + } + elif isinstance(obj, Vertex): + obj.properties = { + p.label: p.value + for p in self.fetch_vertex_props(g, obj) + } + + validate_path_result_type(self, path_obj) + + def _validate_meta_property(self, g, vertex): + + self.assertEqual(len(vertex.properties), 1) + self.assertEqual(len(vertex.properties['key']), 1) + p = vertex.properties['key'][0] + self.assertEqual(p.label, 'key') + self.assertEqual(p.value, 'meta_prop') + self.assertEqual(p.properties, {'k0': 'v0', 'k1': 'v1'}) diff --git a/tests/integration/advanced/graph/fluent/test_graph_implicit_execution.py b/tests/integration/advanced/graph/fluent/test_graph_implicit_execution.py new file mode 100644 index 0000000000..1407dd1ea3 --- /dev/null +++ b/tests/integration/advanced/graph/fluent/test_graph_implicit_execution.py @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from concurrent.futures import Future +from cassandra.datastax.graph.fluent import DseGraph + +from tests.integration import requiredse, DSE_VERSION +from tests.integration.advanced import use_single_node_with_graph +from tests.integration.advanced.graph import GraphTestConfiguration +from tests.integration.advanced.graph.fluent import ( + BaseImplicitExecutionTest, create_traversal_profiles, _AbstractTraversalTest) + + +def setup_module(): + if DSE_VERSION: + dse_options = {'graph': {'realtime_evaluation_timeout_in_seconds': 60}} + use_single_node_with_graph(dse_options=dse_options) + + +@requiredse +@GraphTestConfiguration.generate_tests(traversal=True) +class ImplicitExecutionTest(BaseImplicitExecutionTest, _AbstractTraversalTest): + def _test_iterate_step(self, schema, graphson): + """ + Test to validate that the iterate() step work on all dse versions. + @jira_ticket PYTHON-1155 + @expected_result iterate step works + @test_category dse graph + """ + + g = self.fetch_traversal_source(graphson) + self.execute_graph(schema.fixtures.classic(), graphson) + g.addV('person').property('name', 'Person1').iterate() + + +@requiredse +@GraphTestConfiguration.generate_tests(traversal=True) +class ImplicitAsyncExecutionTest(BaseImplicitExecutionTest): + """ + Test to validate that the traversal async execution works properly. 
+ + @since 3.21.0 + @jira_ticket PYTHON-1129 + + @test_category dse graph + """ + + def setUp(self): + super(ImplicitAsyncExecutionTest, self).setUp() + self.ep_graphson2, self.ep_graphson3 = create_traversal_profiles(self.cluster, self.graph_name) + + def _validate_results(self, results): + results = list(results) + self.assertEqual(len(results), 2) + self.assertIn('vadas', results) + self.assertIn('josh', results) + + def _test_promise(self, schema, graphson): + self.execute_graph(schema.fixtures.classic(), graphson) + g = self.fetch_traversal_source(graphson) + traversal_future = g.V().has('name', 'marko').out('knows').values('name').promise() + self._validate_results(traversal_future.result()) + + def _test_promise_error_is_propagated(self, schema, graphson): + self.execute_graph(schema.fixtures.classic(), graphson) + g = DseGraph().traversal_source(self.session, 'wrong_graph', execution_profile=self.ep) + traversal_future = g.V().has('name', 'marko').out('knows').values('name').promise() + with self.assertRaises(Exception): + traversal_future.result() + + def _test_promise_callback(self, schema, graphson): + self.execute_graph(schema.fixtures.classic(), graphson) + g = self.fetch_traversal_source(graphson) + future = Future() + + def cb(f): + future.set_result(f.result()) + + traversal_future = g.V().has('name', 'marko').out('knows').values('name').promise() + traversal_future.add_done_callback(cb) + self._validate_results(future.result()) + + def _test_promise_callback_on_error(self, schema, graphson): + self.execute_graph(schema.fixtures.classic(), graphson) + g = DseGraph().traversal_source(self.session, 'wrong_graph', execution_profile=self.ep) + future = Future() + + def cb(f): + try: + f.result() + except Exception as e: + future.set_exception(e) + + traversal_future = g.V().has('name', 'marko').out('knows').values('name').promise() + traversal_future.add_done_callback(cb) + with self.assertRaises(Exception): + future.result() diff --git a/tests/integration/advanced/graph/fluent/test_search.py b/tests/integration/advanced/graph/fluent/test_search.py new file mode 100644 index 0000000000..b6857f3560 --- /dev/null +++ b/tests/integration/advanced/graph/fluent/test_search.py @@ -0,0 +1,541 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
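# ---------------------------------------------------------------------------
# Editor's note: the implicit async tests above drive a fluent traversal with
# .promise(), which returns a future-like object exposing result() and
# add_done_callback(). A minimal sketch of that usage, not part of the patch;
# the contact point, graph name 'demo_graph' and profile name 'demo_traversal'
# are assumptions, and the graph is expected to already exist.

from cassandra.cluster import Cluster
from cassandra.datastax.graph.fluent import DseGraph

cluster = Cluster(['127.0.0.1'])
session = cluster.connect()

# Register a graph execution profile configured for fluent (bytecode) queries.
ep = DseGraph().create_execution_profile('demo_graph')
cluster.add_execution_profile('demo_traversal', ep)

g = DseGraph().traversal_source(session, 'demo_graph',
                                execution_profile='demo_traversal')

# Fire the traversal asynchronously; block on result() or attach a callback.
future = g.V().hasLabel('person').values('name').promise()
future.add_done_callback(lambda f: print(f.result()))
names = future.result()

cluster.shutdown()
# ---------------------------------------------------------------------------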
+ +from cassandra.util import Distance +from cassandra import InvalidRequest +from cassandra.graph import GraphProtocol +from cassandra.datastax.graph.fluent import DseGraph +from cassandra.datastax.graph.fluent.predicates import Search, Geo, GeoUnit, CqlCollection + +from tests.integration.advanced import use_single_node_with_graph_and_solr +from tests.integration.advanced.graph import GraphUnitTestCase, CoreGraphSchema, ClassicGraphSchema, GraphTestConfiguration +from tests.integration import greaterthanorequaldse51, DSE_VERSION, requiredse + + +def setup_module(): + if DSE_VERSION: + use_single_node_with_graph_and_solr() + + +class AbstractSearchTest(GraphUnitTestCase): + + def setUp(self): + super(AbstractSearchTest, self).setUp() + self.ep_graphson2 = DseGraph().create_execution_profile(self.graph_name, + graph_protocol=GraphProtocol.GRAPHSON_2_0) + self.ep_graphson3 = DseGraph().create_execution_profile(self.graph_name, + graph_protocol=GraphProtocol.GRAPHSON_3_0) + + self.cluster.add_execution_profile('traversal_graphson2', self.ep_graphson2) + self.cluster.add_execution_profile('traversal_graphson3', self.ep_graphson3) + + def fetch_traversal_source(self, graphson): + ep = self.get_execution_profile(graphson, traversal=True) + return DseGraph().traversal_source(self.session, self.graph_name, execution_profile=ep) + + def _test_search_by_prefix(self, schema, graphson): + """ + Test to validate that solr searches by prefix function. + + @since 1.0.0 + @jira_ticket PYTHON-660 + @expected_result all names starting with Paul should be returned + + @test_category dse graph + """ + self.execute_graph(schema.fixtures.address_book(), graphson) + g = self.fetch_traversal_source(graphson) + traversal = g.V().has("person", "name", Search.prefix("Paul")).values("name") + results_list = self.execute_traversal(traversal, graphson) + self.assertEqual(len(results_list), 1) + self.assertEqual(results_list[0], "Paul Thomas Joe") + + def _test_search_by_regex(self, schema, graphson): + """ + Test to validate that solr searches by regex function. + + @since 1.0.0 + @jira_ticket PYTHON-660 + @expected_result all names containing Paul should be returned + + @test_category dse graph + """ + self.execute_graph(schema.fixtures.address_book(), graphson) + g = self.fetch_traversal_source(graphson) + traversal = g.V().has("person", "name", Search.regex(".*Paul.*")).values("name") + results_list = self.execute_traversal(traversal, graphson) + self.assertEqual(len(results_list), 2) + self.assertIn("Paul Thomas Joe", results_list) + self.assertIn("James Paul Smith", results_list) + + def _test_search_by_token(self, schema, graphson): + """ + Test to validate that solr searches by token. + + @since 1.0.0 + @jira_ticket PYTHON-660 + @expected_result all names with description containing could shoud be returned + + @test_category dse graph + """ + self.execute_graph(schema.fixtures.address_book(), graphson) + g = self.fetch_traversal_source(graphson) + traversal = g.V().has("person", "description", Search.token("cold")).values("name") + results_list = self.execute_traversal(traversal, graphson) + self.assertEqual(len(results_list), 2) + self.assertIn("Jill Alice", results_list) + self.assertIn("George Bill Steve", results_list) + + def _test_search_by_token_prefix(self, schema, graphson): + """ + Test to validate that solr searches by token prefix. 
+ + @since 1.0.0 + @jira_ticket PYTHON-660 + @expected_result all names with description containing a token starting with h are returned + + @test_category dse graph + """ + self.execute_graph(schema.fixtures.address_book(), graphson) + g = self.fetch_traversal_source(graphson) + traversal = g.V().has("person", "description", Search.token_prefix("h")).values("name") + results_list = self.execute_traversal(traversal, graphson) + self.assertEqual(len(results_list), 2) + self.assertIn("Paul Thomas Joe", results_list) + self.assertIn( "James Paul Smith", results_list) + + def _test_search_by_token_regex(self, schema, graphson): + """ + Test to validate that solr searches by token regex. + + @since 1.0.0 + @jira_ticket PYTHON-660 + @expected_result all names with description containing nice or hospital are returned + + @test_category dse graph + """ + self.execute_graph(schema.fixtures.address_book(), graphson) + g = self.fetch_traversal_source(graphson) + traversal = g.V().has("person", "description", Search.token_regex("(nice|hospital)")).values("name") + results_list = self.execute_traversal(traversal, graphson) + self.assertEqual(len(results_list), 2) + self.assertIn("Paul Thomas Joe", results_list ) + self.assertIn( "Jill Alice", results_list ) + + def _assert_in_distance(self, schema, graphson, inside, names): + """ + Helper function that asserts that an exception is arisen if geodetic predicates are used + in cartesian geometry. Also asserts that the expected list is equal to the returned from + the transversal using different search indexes. + """ + def assert_equal_list(L1, L2): + return len(L1) == len(L2) and sorted(L1) == sorted(L2) + + self.execute_graph(schema.fixtures.address_book(), graphson) + g = self.fetch_traversal_source(graphson) + + traversal = g.V().has("person", "pointPropWithBoundsWithSearchIndex", inside).values("name") + if schema is ClassicGraphSchema: + # throws an exception because of a SOLR/Search limitation in the indexing process + # may be resolved in the future + self.assertRaises(InvalidRequest, self.execute_traversal, traversal, graphson) + else: + traversal = g.V().has("person", "pointPropWithBoundsWithSearchIndex", inside).values("name") + results_list = self.execute_traversal(traversal, graphson) + assert_equal_list(names, results_list) + + traversal = g.V().has("person", "pointPropWithBounds", inside).values("name") + results_list = self.execute_traversal(traversal, graphson) + assert_equal_list(names, results_list) + + traversal = g.V().has("person", "pointPropWithGeoBoundsWithSearchIndex", inside).values("name") + results_list = self.execute_traversal(traversal, graphson) + assert_equal_list(names, results_list) + + traversal = g.V().has("person", "pointPropWithGeoBounds", inside).values("name") + results_list = self.execute_traversal(traversal, graphson) + assert_equal_list(names, results_list) + + @greaterthanorequaldse51 + def _test_search_by_distance(self, schema, graphson): + """ + Test to validate that solr searches by distance. + + @since 1.0.0 + @jira_ticket PYTHON-660 + @expected_result all names with a geo location within a 2 degree distance of -92,44 are returned + + @test_category dse graph + """ + self._assert_in_distance(schema, graphson, + Geo.inside(Distance(-92, 44, 2)), + ["Paul Thomas Joe", "George Bill Steve"] + ) + + @greaterthanorequaldse51 + def _test_search_by_distance_meters_units(self, schema, graphson): + """ + Test to validate that solr searches by distance. 
+ + @since 2.0.0 + @jira_ticket PYTHON-698 + @expected_result all names with a geo location within a 56k-meter radius of -92,44 are returned + + @test_category dse graph + """ + self._assert_in_distance(schema, graphson, + Geo.inside(Distance(-92, 44, 56000), GeoUnit.METERS), + ["Paul Thomas Joe"] + ) + + @greaterthanorequaldse51 + def _test_search_by_distance_miles_units(self, schema, graphson): + """ + Test to validate that solr searches by distance. + + @since 2.0.0 + @jira_ticket PYTHON-698 + @expected_result all names with a geo location within a 70-mile radius of -92,44 are returned + + @test_category dse graph + """ + self._assert_in_distance(schema, graphson, + Geo.inside(Distance(-92, 44, 70), GeoUnit.MILES), + ["Paul Thomas Joe", "George Bill Steve"] + ) + + @greaterthanorequaldse51 + def _test_search_by_distance_check_limit(self, schema, graphson): + """ + Test to validate that solr searches by distance using several units. It will also validate + that and exception is arisen if geodetic predicates are used against cartesian geometry + + @since 2.0.0 + @jira_ticket PYTHON-698 + @expected_result if the search distance is below the real distance only one + name will be in the list, otherwise, two + + @test_category dse graph + """ + # Paul Thomas Joe and George Bill Steve are 64.6923761881464 km apart + self._assert_in_distance(schema, graphson, + Geo.inside(Distance(-92.46295, 44.0234, 65), GeoUnit.KILOMETERS), + ["George Bill Steve", "Paul Thomas Joe"] + ) + + self._assert_in_distance(schema, graphson, + Geo.inside(Distance(-92.46295, 44.0234, 64), GeoUnit.KILOMETERS), + ["Paul Thomas Joe"] + ) + + # Paul Thomas Joe and George Bill Steve are 40.19797892069464 miles apart + self._assert_in_distance(schema, graphson, + Geo.inside(Distance(-92.46295, 44.0234, 41), GeoUnit.MILES), + ["George Bill Steve", "Paul Thomas Joe"] + ) + + self._assert_in_distance(schema, graphson, + Geo.inside(Distance(-92.46295, 44.0234, 40), GeoUnit.MILES), + ["Paul Thomas Joe"] + ) + + @greaterthanorequaldse51 + def _test_search_by_fuzzy(self, schema, graphson): + """ + Test to validate that solr searches by distance. + + @since 1.0.0 + @jira_ticket PYTHON-664 + @expected_result all names with a geo location within a 2 radius distance of -92,44 are returned + + @test_category dse graph + """ + self.execute_graph(schema.fixtures.address_book(), graphson) + g = self.fetch_traversal_source(graphson) + traversal = g.V().has("person", "name", Search.fuzzy("Paul Thamas Joe", 1)).values("name") + results_list = self.execute_traversal(traversal, graphson) + self.assertEqual(len(results_list), 1) + self.assertIn("Paul Thomas Joe", results_list) + + traversal = g.V().has("person", "name", Search.fuzzy("Paul Thames Joe", 1)).values("name") + results_list = self.execute_traversal(traversal, graphson) + self.assertEqual(len(results_list), 0) + + @greaterthanorequaldse51 + def _test_search_by_fuzzy_token(self, schema, graphson): + """ + Test to validate that fuzzy searches. 
+ + @since 1.0.0 + @jira_ticket PYTHON-664 + @expected_result all names with that differ from the search criteria by one letter should be returned + + @test_category dse graph + """ + self.execute_graph(schema.fixtures.address_book(), graphson) + g = self.fetch_traversal_source(graphson) + traversal = g.V().has("person", "description", Search.token_fuzzy("lives", 1)).values("name") + # Should match 'Paul Thomas Joe' since description contains 'Lives' + # Should match 'James Paul Joe' since description contains 'Likes' + results_list = self.execute_traversal(traversal, graphson) + self.assertEqual(len(results_list), 2) + self.assertIn("Paul Thomas Joe", results_list) + self.assertIn("James Paul Smith", results_list) + + traversal = g.V().has("person", "description", Search.token_fuzzy("loues", 1)).values("name") + results_list = self.execute_traversal(traversal, graphson) + self.assertEqual(len(results_list), 0) + + @greaterthanorequaldse51 + def _test_search_by_phrase(self, schema, graphson): + """ + Test to validate that phrase searches. + + @since 1.0.0 + @jira_ticket PYTHON-664 + @expected_result all names with that differ from the search phrase criteria by two letter should be returned + + @test_category dse graph + """ + self.execute_graph(schema.fixtures.address_book(), graphson) + g = self.fetch_traversal_source(graphson) + traversal = g.V().has("person", "description", Search.phrase("a cold", 2)).values("name") + #Should match 'George Bill Steve' since 'A cold dude' is at distance of 0 for 'a cold'. + #Should match 'Jill Alice' since 'Enjoys a very nice cold coca cola' is at distance of 2 for 'a cold'. + results_list = self.execute_traversal(traversal, graphson) + self.assertEqual(len(results_list), 2) + self.assertIn('George Bill Steve', results_list) + self.assertIn('Jill Alice', results_list) + + traversal = g.V().has("person", "description", Search.phrase("a bald", 2)).values("name") + results_list = self.execute_traversal(traversal, graphson) + self.assertEqual(len(results_list), 0) + + +@requiredse +@GraphTestConfiguration.generate_tests(traversal=True) +class ImplicitSearchTest(AbstractSearchTest): + """ + This test class will execute all tests of the AbstractSearchTest using implicit execution + All traversals will be run directly using toList() + """ + def fetch_key_from_prop(self, property): + return property.key + + def execute_traversal(self, traversal, graphson=None): + return traversal.toList() + + +@requiredse +@GraphTestConfiguration.generate_tests(traversal=True) +class ExplicitSearchTest(AbstractSearchTest): + """ + This test class will execute all tests of the AbstractSearchTest using implicit execution + All traversals will be converted to byte code then they will be executed explicitly. 
+ """ + + def execute_traversal(self, traversal, graphson): + ep = self.get_execution_profile(graphson, traversal=True) + ep = self.session.get_execution_profile(ep) + context = None + if graphson == GraphProtocol.GRAPHSON_3_0: + context = { + 'cluster': self.cluster, + 'graph_name': ep.graph_options.graph_name.decode('utf-8') if ep.graph_options.graph_name else None + } + query = DseGraph.query_from_traversal(traversal, graphson, context=context) + #Use an ep that is configured with the correct row factory, and bytecode-json language flat set + result_set = self.execute_graph(query, graphson, traversal=True) + return list(result_set) + + +@requiredse +class BaseCqlCollectionPredicatesTest(GraphUnitTestCase): + + def setUp(self): + super(BaseCqlCollectionPredicatesTest, self).setUp() + self.ep_graphson3 = DseGraph().create_execution_profile(self.graph_name, + graph_protocol=GraphProtocol.GRAPHSON_3_0) + self.cluster.add_execution_profile('traversal_graphson3', self.ep_graphson3) + + def fetch_traversal_source(self, graphson): + ep = self.get_execution_profile(graphson, traversal=True) + return DseGraph().traversal_source(self.session, self.graph_name, execution_profile=ep) + + def setup_vertex_label(self, graphson): + ep = self.get_execution_profile(graphson) + self.session.execute_graph(""" + schema.vertexLabel('cqlcollections').ifNotExists().partitionBy('name', Varchar) + .property('list', listOf(Text)) + .property('frozen_list', frozen(listOf(Text))) + .property('set', setOf(Text)) + .property('frozen_set', frozen(setOf(Text))) + .property('map_keys', mapOf(Int, Text)) + .property('map_values', mapOf(Int, Text)) + .property('map_entries', mapOf(Int, Text)) + .property('frozen_map', frozen(mapOf(Int, Text))) + .create() + """, execution_profile=ep) + + self.session.execute_graph(""" + schema.vertexLabel('cqlcollections').secondaryIndex('list').by('list').create(); + schema.vertexLabel('cqlcollections').secondaryIndex('frozen_list').by('frozen_list').indexFull().create(); + schema.vertexLabel('cqlcollections').secondaryIndex('set').by('set').create(); + schema.vertexLabel('cqlcollections').secondaryIndex('frozen_set').by('frozen_set').indexFull().create(); + schema.vertexLabel('cqlcollections').secondaryIndex('map_keys').by('map_keys').indexKeys().create(); + schema.vertexLabel('cqlcollections').secondaryIndex('map_values').by('map_values').indexValues().create(); + schema.vertexLabel('cqlcollections').secondaryIndex('map_entries').by('map_entries').indexEntries().create(); + schema.vertexLabel('cqlcollections').secondaryIndex('frozen_map').by('frozen_map').indexFull().create(); + """, execution_profile=ep) + + def _test_contains_list(self, schema, graphson): + """ + Test to validate that the cql predicate contains works with list + + @since TODO dse 6.8 + @jira_ticket PYTHON-1039 + @expected_result contains predicate work on a list + + @test_category dse graph + """ + self.setup_vertex_label(graphson) + g = self.fetch_traversal_source(graphson) + traversal = g.addV("cqlcollections").property("name", "list1").property("list", ['item1', 'item2']) + self.execute_traversal(traversal, graphson) + traversal = g.addV("cqlcollections").property("name", "list2").property("list", ['item3', 'item4']) + self.execute_traversal(traversal, graphson) + traversal = g.V().has("cqlcollections", "list", CqlCollection.contains("item1")).values("name") + results_list = self.execute_traversal(traversal, graphson) + self.assertEqual(len(results_list), 1) + self.assertIn("list1", results_list) + + def 
_test_contains_set(self, schema, graphson): + """ + Test to validate that the cql predicate contains works with set + + @since TODO dse 6.8 + @jira_ticket PYTHON-1039 + @expected_result contains predicate work on a set + + @test_category dse graph + """ + self.setup_vertex_label(graphson) + g = self.fetch_traversal_source(graphson) + traversal = g.addV("cqlcollections").property("name", "set1").property("set", {'item1', 'item2'}) + self.execute_traversal(traversal, graphson) + traversal = g.addV("cqlcollections").property("name", "set2").property("set", {'item3', 'item4'}) + self.execute_traversal(traversal, graphson) + traversal = g.V().has("cqlcollections", "set", CqlCollection.contains("item1")).values("name") + results_list = self.execute_traversal(traversal, graphson) + self.assertEqual(len(results_list), 1) + self.assertIn("set1", results_list) + + def _test_contains_key_map(self, schema, graphson): + """ + Test to validate that the cql predicate contains_key works with map + + @since TODO dse 6.8 + @jira_ticket PYTHON-1039 + @expected_result contains_key predicate work on a map + + @test_category dse graph + """ + self.setup_vertex_label(graphson) + g = self.fetch_traversal_source(graphson) + traversal = g.addV("cqlcollections").property("name", "map1").property("map_keys", {0: 'item1', 1: 'item2'}) + self.execute_traversal(traversal, graphson) + traversal = g.addV("cqlcollections").property("name", "map2").property("map_keys", {2: 'item3', 3: 'item4'}) + self.execute_traversal(traversal, graphson) + traversal = g.V().has("cqlcollections", "map_keys", CqlCollection.contains_key(0)).values("name") + results_list = self.execute_traversal(traversal, graphson) + self.assertEqual(len(results_list), 1) + self.assertIn("map1", results_list) + + def _test_contains_value_map(self, schema, graphson): + """ + Test to validate that the cql predicate contains_value works with map + + @since TODO dse 6.8 + @jira_ticket PYTHON-1039 + @expected_result contains_value predicate work on a map + + @test_category dse graph + """ + self.setup_vertex_label(graphson) + g = self.fetch_traversal_source(graphson) + traversal = g.addV("cqlcollections").property("name", "map1").property("map_values", {0: 'item1', 1: 'item2'}) + self.execute_traversal(traversal, graphson) + traversal = g.addV("cqlcollections").property("name", "map2").property("map_values", {2: 'item3', 3: 'item4'}) + self.execute_traversal(traversal, graphson) + traversal = g.V().has("cqlcollections", "map_values", CqlCollection.contains_value('item3')).values("name") + results_list = self.execute_traversal(traversal, graphson) + self.assertEqual(len(results_list), 1) + self.assertIn("map2", results_list) + + def _test_entry_eq_map(self, schema, graphson): + """ + Test to validate that the cql predicate entry_eq works with map + + @since TODO dse 6.8 + @jira_ticket PYTHON-1039 + @expected_result entry_eq predicate work on a map + + @test_category dse graph + """ + self.setup_vertex_label(graphson) + g = self.fetch_traversal_source(graphson) + traversal = g.addV("cqlcollections").property("name", "map1").property("map_entries", {0: 'item1', 1: 'item2'}) + self.execute_traversal(traversal, graphson) + traversal = g.addV("cqlcollections").property("name", "map2").property("map_entries", {2: 'item3', 3: 'item4'}) + self.execute_traversal(traversal, graphson) + traversal = g.V().has("cqlcollections", "map_entries", CqlCollection.entry_eq([2, 'item3'])).values("name") + results_list = self.execute_traversal(traversal, graphson) + 
self.assertEqual(len(results_list), 1) + self.assertIn("map2", results_list) + + +@requiredse +@GraphTestConfiguration.generate_tests(traversal=True, schema=CoreGraphSchema) +class ImplicitCqlCollectionPredicatesTest(BaseCqlCollectionPredicatesTest): + """ + This test class will execute all tests of the BaseCqlCollectionTest using implicit execution + All traversals will be run directly using toList() + """ + + def execute_traversal(self, traversal, graphson=None): + return traversal.toList() + + +@requiredse +@GraphTestConfiguration.generate_tests(traversal=True, schema=CoreGraphSchema) +class ExplicitCqlCollectionPredicatesTest(BaseCqlCollectionPredicatesTest): + """ + This test class will execute all tests of the AbstractSearchTest using implicit execution + All traversals will be converted to byte code then they will be executed explicitly. + """ + + def execute_traversal(self, traversal, graphson): + ep = self.get_execution_profile(graphson, traversal=True) + ep = self.session.get_execution_profile(ep) + context = None + if graphson == GraphProtocol.GRAPHSON_3_0: + context = { + 'cluster': self.cluster, + 'graph_name': ep.graph_options.graph_name.decode('utf-8') if ep.graph_options.graph_name else None + } + query = DseGraph.query_from_traversal(traversal, graphson, context=context) + result_set = self.execute_graph(query, graphson, traversal=True) + return list(result_set) diff --git a/tests/integration/advanced/graph/test_graph.py b/tests/integration/advanced/graph/test_graph.py new file mode 100644 index 0000000000..3624b5e1ef --- /dev/null +++ b/tests/integration/advanced/graph/test_graph.py @@ -0,0 +1,272 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
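+
+# A rough sketch of the execution-profile plumbing the timeout and profile
+# tests in this module rely on ("test_timeout_1" is just an example profile
+# name; each test registers its own):
+#
+#     ep = session.execution_profile_clone_update(EXEC_PROFILE_GRAPH_DEFAULT)
+#     ep.graph_options = ep.graph_options.copy()
+#     ep.graph_options.graph_source = "test_timeout_1"
+#     cluster.add_execution_profile("test_timeout_1", ep)
+#     session.execute_graph("g.V()", execution_profile="test_timeout_1")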
+ +import re + +from cassandra import OperationTimedOut, InvalidRequest +from cassandra.protocol import SyntaxException +from cassandra.policies import WhiteListRoundRobinPolicy +from cassandra.cluster import NoHostAvailable +from cassandra.cluster import EXEC_PROFILE_GRAPH_DEFAULT, GraphExecutionProfile +from cassandra.graph import single_object_row_factory, Vertex, graph_object_row_factory, \ + graph_graphson2_row_factory, graph_graphson3_row_factory +from cassandra.util import SortedSet + +from tests.integration import DSE_VERSION, greaterthanorequaldse51, greaterthanorequaldse68, \ + requiredse, TestCluster +from tests.integration.advanced.graph import BasicGraphUnitTestCase, GraphUnitTestCase, \ + GraphProtocol, ClassicGraphSchema, CoreGraphSchema, use_single_node_with_graph + + +def setup_module(): + if DSE_VERSION: + dse_options = {'graph': {'realtime_evaluation_timeout_in_seconds': 60}} + use_single_node_with_graph(dse_options=dse_options) + + +@requiredse +class GraphTimeoutTests(BasicGraphUnitTestCase): + + def test_should_wait_indefinitely_by_default(self): + """ + Tests that by default the client should wait indefinitely for server timeouts + + @since 1.0.0 + @jira_ticket PYTHON-589 + + @test_category dse graph + """ + desired_timeout = 1000 + + graph_source = "test_timeout_1" + ep_name = graph_source + ep = self.session.execution_profile_clone_update(EXEC_PROFILE_GRAPH_DEFAULT) + ep.graph_options = ep.graph_options.copy() + ep.graph_options.graph_source = graph_source + self.cluster.add_execution_profile(ep_name, ep) + + to_run = '''graph.schema().config().option("graph.traversal_sources.{0}.evaluation_timeout").set('{1} ms')'''.format( + graph_source, desired_timeout) + self.session.execute_graph(to_run, execution_profile=ep_name) + with self.assertRaises(InvalidRequest) as ir: + self.session.execute_graph("java.util.concurrent.TimeUnit.MILLISECONDS.sleep(35000L);1+1", + execution_profile=ep_name) + self.assertTrue("evaluation exceeded the configured threshold of 1000" in str(ir.exception) or + "evaluation exceeded the configured threshold of evaluation_timeout at 1000" in str( + ir.exception)) + + def test_request_timeout_less_then_server(self): + """ + Tests that with explicit request_timeouts set, that a server timeout is honored if it's relieved prior to the + client timeout + + @since 1.0.0 + @jira_ticket PYTHON-589 + + @test_category dse graph + """ + desired_timeout = 1000 + graph_source = "test_timeout_2" + ep_name = graph_source + ep = self.session.execution_profile_clone_update(EXEC_PROFILE_GRAPH_DEFAULT, request_timeout=32) + ep.graph_options = ep.graph_options.copy() + ep.graph_options.graph_source = graph_source + self.cluster.add_execution_profile(ep_name, ep) + + to_run = '''graph.schema().config().option("graph.traversal_sources.{0}.evaluation_timeout").set('{1} ms')'''.format( + graph_source, desired_timeout) + self.session.execute_graph(to_run, execution_profile=ep_name) + with self.assertRaises(InvalidRequest) as ir: + self.session.execute_graph("java.util.concurrent.TimeUnit.MILLISECONDS.sleep(35000L);1+1", + execution_profile=ep_name) + self.assertTrue("evaluation exceeded the configured threshold of 1000" in str(ir.exception) or + "evaluation exceeded the configured threshold of evaluation_timeout at 1000" in str( + ir.exception)) + + def test_server_timeout_less_then_request(self): + """ + Tests that with explicit request_timeouts set, that a client timeout is honored if it's triggered prior to the + server sending a timeout. 
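+
+        A minimal sketch of the profile setup used below (request_timeout is
+        the client-side timeout, in seconds, on the cloned profile):
+
+            ep = self.session.execution_profile_clone_update(
+                EXEC_PROFILE_GRAPH_DEFAULT, request_timeout=1)
+            ep.graph_options = ep.graph_options.copy()
+            ep.graph_options.graph_source = "test_timeout_3"
+            self.cluster.add_execution_profile("test_timeout_3", ep)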
+ + @since 1.0.0 + @jira_ticket PYTHON-589 + + @test_category dse graph + """ + graph_source = "test_timeout_3" + ep_name = graph_source + ep = self.session.execution_profile_clone_update(EXEC_PROFILE_GRAPH_DEFAULT, request_timeout=1) + ep.graph_options = ep.graph_options.copy() + ep.graph_options.graph_source = graph_source + self.cluster.add_execution_profile(ep_name, ep) + server_timeout = 10000 + to_run = '''graph.schema().config().option("graph.traversal_sources.{0}.evaluation_timeout").set('{1} ms')'''.format( + graph_source, server_timeout) + self.session.execute_graph(to_run, execution_profile=ep_name) + + with self.assertRaises(Exception) as e: + self.session.execute_graph("java.util.concurrent.TimeUnit.MILLISECONDS.sleep(35000L);1+1", + execution_profile=ep_name) + self.assertTrue(isinstance(e, InvalidRequest) or isinstance(e, OperationTimedOut)) + + +@requiredse +class GraphProfileTests(BasicGraphUnitTestCase): + def test_graph_profile(self): + """ + Test verifying various aspects of graph config properties. + + @since 1.0.0 + @jira_ticket PYTHON-570 + + @test_category dse graph + """ + hosts = self.cluster.metadata.all_hosts() + first_host = hosts[0].address + second_hosts = "1.2.3.4" + + self._execute(ClassicGraphSchema.fixtures.classic(), graphson=GraphProtocol.GRAPHSON_1_0) + # Create various execution policies + exec_dif_factory = GraphExecutionProfile(row_factory=single_object_row_factory) + exec_dif_factory.graph_options.graph_name = self.graph_name + exec_dif_lbp = GraphExecutionProfile(load_balancing_policy=WhiteListRoundRobinPolicy([first_host])) + exec_dif_lbp.graph_options.graph_name = self.graph_name + exec_bad_lbp = GraphExecutionProfile(load_balancing_policy=WhiteListRoundRobinPolicy([second_hosts])) + exec_dif_lbp.graph_options.graph_name = self.graph_name + exec_short_timeout = GraphExecutionProfile(request_timeout=1, + load_balancing_policy=WhiteListRoundRobinPolicy([first_host])) + exec_short_timeout.graph_options.graph_name = self.graph_name + + # Add a single execution policy on cluster creation + local_cluster = TestCluster(execution_profiles={"exec_dif_factory": exec_dif_factory}) + local_session = local_cluster.connect() + self.addCleanup(local_cluster.shutdown) + + rs1 = self.session.execute_graph('g.V()') + rs2 = local_session.execute_graph('g.V()', execution_profile='exec_dif_factory') + + # Verify default and non default policy works + self.assertFalse(isinstance(rs2[0], Vertex)) + self.assertTrue(isinstance(rs1[0], Vertex)) + # Add other policies validate that lbp are honored + local_cluster.add_execution_profile("exec_dif_ldp", exec_dif_lbp) + local_session.execute_graph('g.V()', execution_profile="exec_dif_ldp") + local_cluster.add_execution_profile("exec_bad_lbp", exec_bad_lbp) + with self.assertRaises(NoHostAvailable): + local_session.execute_graph('g.V()', execution_profile="exec_bad_lbp") + + # Try with missing EP + with self.assertRaises(ValueError): + local_session.execute_graph('g.V()', execution_profile='bad_exec_profile') + + # Validate that timeout is honored + local_cluster.add_execution_profile("exec_short_timeout", exec_short_timeout) + with self.assertRaises(Exception) as e: + self.assertTrue(isinstance(e, InvalidRequest) or isinstance(e, OperationTimedOut)) + local_session.execute_graph('java.util.concurrent.TimeUnit.MILLISECONDS.sleep(2000L);', + execution_profile='exec_short_timeout') + + +@requiredse +class GraphMetadataTest(BasicGraphUnitTestCase): + + @greaterthanorequaldse51 + def test_dse_workloads(self): + """ + Test to 
ensure dse_workloads is populated appropriately. + Field added in DSE 5.1 + + @since DSE 2.0 + @jira_ticket PYTHON-667 + @expected_result dse_workloads set is set on host model + + @test_category metadata + """ + for host in self.cluster.metadata.all_hosts(): + self.assertIsInstance(host.dse_workloads, SortedSet) + self.assertIn("Cassandra", host.dse_workloads) + self.assertIn("Graph", host.dse_workloads) + + +@requiredse +class GraphExecutionProfileOptionsResolveTest(GraphUnitTestCase): + """ + Test that the execution profile options are properly resolved for graph queries. + + @since DSE 6.8 + @jira_ticket PYTHON-1004 PYTHON-1056 + @expected_result execution profile options are properly determined following the rules. + """ + + def test_default_options(self): + ep = self.session.get_execution_profile(EXEC_PROFILE_GRAPH_DEFAULT) + self.assertEqual(ep.graph_options.graph_protocol, None) + self.assertEqual(ep.row_factory, None) + self.session._resolve_execution_profile_options(ep) + self.assertEqual(ep.graph_options.graph_protocol, GraphProtocol.GRAPHSON_1_0) + self.assertEqual(ep.row_factory, graph_object_row_factory) + + def test_default_options_when_not_groovy(self): + ep = self.session.get_execution_profile(EXEC_PROFILE_GRAPH_DEFAULT) + self.assertEqual(ep.graph_options.graph_protocol, None) + self.assertEqual(ep.row_factory, None) + ep.graph_options.graph_language = 'whatever' + self.session._resolve_execution_profile_options(ep) + self.assertEqual(ep.graph_options.graph_protocol, GraphProtocol.GRAPHSON_2_0) + self.assertEqual(ep.row_factory, graph_graphson2_row_factory) + + def test_default_options_when_explicitly_specified(self): + ep = self.session.get_execution_profile(EXEC_PROFILE_GRAPH_DEFAULT) + self.assertEqual(ep.graph_options.graph_protocol, None) + self.assertEqual(ep.row_factory, None) + obj = object() + ep.graph_options.graph_protocol = obj + ep.row_factory = obj + self.session._resolve_execution_profile_options(ep) + self.assertEqual(ep.graph_options.graph_protocol, obj) + self.assertEqual(ep.row_factory, obj) + + @greaterthanorequaldse68 + def test_graph_protocol_default_for_core_is_graphson3(self): + """Test that graphson3 is automatically resolved for a core graph query""" + self.setup_graph(CoreGraphSchema) + ep = self.session.get_execution_profile(EXEC_PROFILE_GRAPH_DEFAULT) + self.assertEqual(ep.graph_options.graph_protocol, None) + self.assertEqual(ep.row_factory, None) + # Ensure we have the graph metadata + self.session.cluster.refresh_schema_metadata() + self.session._resolve_execution_profile_options(ep) + self.assertEqual(ep.graph_options.graph_protocol, GraphProtocol.GRAPHSON_3_0) + self.assertEqual(ep.row_factory, graph_graphson3_row_factory) + + self.execute_graph_queries(CoreGraphSchema.fixtures.classic(), verify_graphson=GraphProtocol.GRAPHSON_3_0) + + @greaterthanorequaldse68 + def test_graph_protocol_default_for_core_fallback_to_graphson1_if_no_graph_name(self): + """Test that graphson1 is set when we cannot detect if it's a core graph""" + self.setup_graph(CoreGraphSchema) + default_ep = self.session.get_execution_profile(EXEC_PROFILE_GRAPH_DEFAULT) + graph_options = default_ep.graph_options.copy() + graph_options.graph_name = None + ep = self.session.execution_profile_clone_update(EXEC_PROFILE_GRAPH_DEFAULT, graph_options=graph_options) + self.session._resolve_execution_profile_options(ep) + self.assertEqual(ep.graph_options.graph_protocol, GraphProtocol.GRAPHSON_1_0) + self.assertEqual(ep.row_factory, graph_object_row_factory) + + regex = 
re.compile(".*Variable.*is unknown.*", re.S) + with self.assertRaisesRegex(SyntaxException, regex): + self.execute_graph_queries(CoreGraphSchema.fixtures.classic(), + execution_profile=ep, verify_graphson=GraphProtocol.GRAPHSON_1_0) diff --git a/tests/integration/advanced/graph/test_graph_cont_paging.py b/tests/integration/advanced/graph/test_graph_cont_paging.py new file mode 100644 index 0000000000..17c43c4e3d --- /dev/null +++ b/tests/integration/advanced/graph/test_graph_cont_paging.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cassandra.cluster import ContinuousPagingOptions + +from tests.integration import greaterthanorequaldse68 +from tests.integration.advanced.graph import GraphUnitTestCase, CoreGraphSchema, GraphTestConfiguration + + +@greaterthanorequaldse68 +@GraphTestConfiguration.generate_tests(schema=CoreGraphSchema) +class GraphPagingTest(GraphUnitTestCase): + + def _setup_data(self, schema, graphson): + self.execute_graph("schema.vertexLabel('person').ifNotExists().partitionBy('name', Text).property('age', Int).create();", graphson) + for i in range(100): + self.execute_graph("g.addV('person').property('name', 'batman-{}')".format(i), graphson) + + def _test_cont_paging_is_enabled_by_default(self, schema, graphson): + """ + Test that graph paging is automatically enabled with a >=6.8 cluster. + + @jira_ticket PYTHON-1045 + @expected_result the response future has a continuous_paging_session since graph paging is enabled + + @test_category dse graph + """ + ep = self.get_execution_profile(graphson) + self._setup_data(schema, graphson) + rf = self.session.execute_graph_async("g.V()", execution_profile=ep) + results = list(rf.result()) + self.assertIsNotNone(rf._continuous_paging_session) + self.assertEqual(len(results), 100) + + def _test_cont_paging_can_be_disabled(self, schema, graphson): + """ + Test that graph paging can be disabled. + + @jira_ticket PYTHON-1045 + @expected_result the response future doesn't have a continuous_paging_session since graph paging is disabled + + @test_category dse graph + """ + ep = self.get_execution_profile(graphson) + new_ep = self.session.execution_profile_clone_update(ep, continuous_paging_options=None) + self._setup_data(schema, graphson) + rf = self.session.execute_graph_async("g.V()", execution_profile=new_ep) + results = list(rf.result()) + self.assertIsNone(rf._continuous_paging_session) + self.assertEqual(len(results), 100) + + def _test_cont_paging_with_custom_options(self, schema, graphson): + """ + Test that we can specify custom paging options. 
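+
+        Sketch of the override applied below (ContinuousPagingOptions comes
+        from cassandra.cluster; with default_fetch_size = 10 and max_pages=1
+        only the first page of 10 rows is returned):
+
+            new_ep = self.session.execution_profile_clone_update(
+                ep, continuous_paging_options=ContinuousPagingOptions(max_pages=1))
+            results = list(self.session.execute_graph("g.V()", execution_profile=new_ep))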
+ + @jira_ticket PYTHON-1045 + @expected_result we get only the desired number of results + + @test_category dse graph + """ + ep = self.get_execution_profile(graphson) + new_ep = self.session.execution_profile_clone_update( + ep, continuous_paging_options=ContinuousPagingOptions(max_pages=1)) + self._setup_data(schema, graphson) + self.session.default_fetch_size = 10 + results = list(self.session.execute_graph("g.V()", execution_profile=new_ep)) + self.assertEqual(len(results), 10) diff --git a/tests/integration/advanced/graph/test_graph_datatype.py b/tests/integration/advanced/graph/test_graph_datatype.py new file mode 100644 index 0000000000..0fda2f0d44 --- /dev/null +++ b/tests/integration/advanced/graph/test_graph_datatype.py @@ -0,0 +1,268 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import time +import logging +from packaging.version import Version +from collections import namedtuple + +from cassandra.cluster import EXEC_PROFILE_GRAPH_DEFAULT +from cassandra.graph import graph_result_row_factory +from cassandra.graph.query import GraphProtocol +from cassandra.graph.types import VertexProperty + +from tests.util import wait_until +from tests.integration.advanced.graph import BasicGraphUnitTestCase, ClassicGraphFixtures, \ + ClassicGraphSchema, CoreGraphSchema +from tests.integration.advanced.graph import VertexLabel, GraphTestConfiguration, GraphUnitTestCase +from tests.integration import DSE_VERSION, requiredse + +log = logging.getLogger(__name__) + + +@requiredse +class GraphBasicDataTypesTests(BasicGraphUnitTestCase): + + def test_result_types(self): + """ + Test to validate that the edge and vertex version of results are constructed correctly. + + @since 1.0.0 + @jira_ticket PYTHON-479 + @expected_result edge/vertex result types should be unpacked correctly. 
+ @test_category dse graph + """ + queries, params = ClassicGraphFixtures.multiple_fields() + for query in queries: + self.session.execute_graph(query, params) + + prof = self.session.execution_profile_clone_update(EXEC_PROFILE_GRAPH_DEFAULT, row_factory=graph_result_row_factory) # requires simplified row factory to avoid shedding id/~type information used for validation below + rs = self.session.execute_graph("g.V()", execution_profile=prof) + + for result in rs: + self._validate_type(result) + + def _validate_type(self, vertex): + for properties in vertex.properties.values(): + prop = properties[0] + + if DSE_VERSION >= Version("5.1"): + type_indicator = prop['id']['~label'] + else: + type_indicator = prop['id']['~type'] + + if any(type_indicator.startswith(t) for t in + ('int', 'short', 'long', 'bigint', 'decimal', 'smallint', 'varint')): + typ = int + elif any(type_indicator.startswith(t) for t in ('float', 'double')): + typ = float + elif any(type_indicator.startswith(t) for t in ('duration', 'date', 'negdate', 'time', + 'blob', 'timestamp', 'point', 'linestring', 'polygon', + 'inet', 'uuid')): + typ = str + else: + pass + self.fail("Received unexpected type: %s" % type_indicator) + self.assertIsInstance(prop['value'], typ) + + +class GenericGraphDataTypeTest(GraphUnitTestCase): + + def _test_all_datatypes(self, schema, graphson): + ep = self.get_execution_profile(graphson) + + for data in schema.fixtures.datatypes().values(): + typ, value, deserializer = data + vertex_label = VertexLabel([typ]) + property_name = next(iter(vertex_label.non_pk_properties.keys())) + schema.create_vertex_label(self.session, vertex_label, execution_profile=ep) + vertex = list(schema.add_vertex(self.session, vertex_label, property_name, value, execution_profile=ep))[0] + + def get_vertex_properties(): + return list(schema.get_vertex_properties( + self.session, vertex, execution_profile=ep)) + + prop_returned = 1 if DSE_VERSION < Version('5.1') else 2 # include pkid >=5.1 + wait_until( + lambda: len(get_vertex_properties()) == prop_returned, 0.2, 15) + + vertex_properties = get_vertex_properties() + if graphson == GraphProtocol.GRAPHSON_1_0: + vertex_properties = [vp.as_vertex_property() for vp in vertex_properties] + + for vp in vertex_properties: + if vp.label == 'pkid': + continue + + self.assertIsInstance(vp, VertexProperty) + self.assertEqual(vp.label, property_name) + if graphson == GraphProtocol.GRAPHSON_1_0: + deserialized_value = deserializer(vp.value) if deserializer else vp.value + self.assertEqual(deserialized_value, value) + else: + self.assertEqual(vp.value, value) + + def __test_udt(self, schema, graphson, address_class, address_with_tags_class, + complex_address_class, complex_address_with_owners_class): + if schema is not CoreGraphSchema or DSE_VERSION < Version('6.8'): + raise unittest.SkipTest("Graph UDT is only supported with DSE 6.8+ and Core graphs.") + + ep = self.get_execution_profile(graphson) + + Address = address_class + AddressWithTags = address_with_tags_class + ComplexAddress = complex_address_class + ComplexAddressWithOwners = complex_address_with_owners_class + + # setup udt + self.session.execute_graph(""" + schema.type('address').property('address', Text).property('city', Text).property('state', Text).create(); + schema.type('addressTags').property('address', Text).property('city', Text).property('state', Text). + property('tags', setOf(Text)).create(); + schema.type('complexAddress').property('address', Text).property('address_tags', frozen(typeOf('addressTags'))). 
+ property('city', Text).property('state', Text).property('props', mapOf(Text, Int)).create(); + schema.type('complexAddressWithOwners').property('address', Text). + property('address_tags', frozen(typeOf('addressTags'))). + property('city', Text).property('state', Text).property('props', mapOf(Text, Int)). + property('owners', frozen(listOf(tupleOf(Text, Int)))).create(); + """, execution_profile=ep) + + time.sleep(2) # wait the UDT to be discovered + self.session.cluster.register_user_type(self.graph_name, 'address', Address) + self.session.cluster.register_user_type(self.graph_name, 'addressTags', AddressWithTags) + self.session.cluster.register_user_type(self.graph_name, 'complexAddress', ComplexAddress) + self.session.cluster.register_user_type(self.graph_name, 'complexAddressWithOwners', ComplexAddressWithOwners) + + data = { + "udt1": ["typeOf('address')", Address('1440 Rd Smith', 'Quebec', 'QC')], + "udt2": ["tupleOf(typeOf('address'), Text)", (Address('1440 Rd Smith', 'Quebec', 'QC'), 'hello')], + "udt3": ["tupleOf(frozen(typeOf('address')), Text)", (Address('1440 Rd Smith', 'Quebec', 'QC'), 'hello')], + "udt4": ["tupleOf(tupleOf(Int, typeOf('address')), Text)", + ((42, Address('1440 Rd Smith', 'Quebec', 'QC')), 'hello')], + "udt5": ["tupleOf(tupleOf(Int, typeOf('addressTags')), Text)", + ((42, AddressWithTags('1440 Rd Smith', 'Quebec', 'QC', {'t1', 't2'})), 'hello')], + "udt6": ["tupleOf(tupleOf(Int, typeOf('complexAddress')), Text)", + ((42, ComplexAddress('1440 Rd Smith', + AddressWithTags('1440 Rd Smith', 'Quebec', 'QC', {'t1', 't2'}), + 'Quebec', 'QC', {'p1': 42, 'p2': 33})), 'hello')], + "udt7": ["tupleOf(tupleOf(Int, frozen(typeOf('complexAddressWithOwners'))), Text)", + ((42, ComplexAddressWithOwners( + '1440 Rd Smith', + AddressWithTags('1440 CRd Smith', 'Quebec', 'QC', {'t1', 't2'}), + 'Quebec', 'QC', {'p1': 42, 'p2': 33}, [('Mike', 43), ('Gina', 39)]) + ), 'hello')] + } + + for typ, value in data.values(): + vertex_label = VertexLabel([typ]) + property_name = next(iter(vertex_label.non_pk_properties.keys())) + schema.create_vertex_label(self.session, vertex_label, execution_profile=ep) + + vertex = list(schema.add_vertex(self.session, vertex_label, property_name, value, execution_profile=ep))[0] + + def get_vertex_properties(): + return list(schema.get_vertex_properties( + self.session, vertex, execution_profile=ep)) + + wait_until( + lambda: len(get_vertex_properties()) == 2, 0.2, 15) + + vertex_properties = get_vertex_properties() + for vp in vertex_properties: + if vp.label == 'pkid': + continue + + self.assertIsInstance(vp, VertexProperty) + self.assertEqual(vp.label, property_name) + self.assertEqual(vp.value, value) + + def _test_udt_with_classes(self, schema, graphson): + class Address(object): + + def __init__(self, address, city, state): + self.address = address + self.city = city + self.state = state + + def __eq__(self, other): + return self.address == other.address and self.city == other.city and self.state == other.state + + class AddressWithTags(object): + + def __init__(self, address, city, state, tags): + self.address = address + self.city = city + self.state = state + self.tags = tags + + def __eq__(self, other): + return (self.address == other.address and self.city == other.city + and self.state == other.state and self.tags == other.tags) + + class ComplexAddress(object): + + def __init__(self, address, address_tags, city, state, props): + self.address = address + self.address_tags = address_tags + self.city = city + self.state = state + self.props = 
props + + def __eq__(self, other): + return (self.address == other.address and self.address_tags == other.address_tags + and self.city == other.city and self.state == other.state + and self.props == other.props) + + class ComplexAddressWithOwners(object): + + def __init__(self, address, address_tags, city, state, props, owners): + self.address = address + self.address_tags = address_tags + self.city = city + self.state = state + self.props = props + self.owners = owners + + def __eq__(self, other): + return (self.address == other.address and self.address_tags == other.address_tags + and self.city == other.city and self.state == other.state + and self.props == other.props and self.owners == other.owners) + + self.__test_udt(schema, graphson, Address, AddressWithTags, ComplexAddress, ComplexAddressWithOwners) + + def _test_udt_with_namedtuples(self, schema, graphson): + AddressTuple = namedtuple('Address', ('address', 'city', 'state')) + AddressWithTagsTuple = namedtuple('AddressWithTags', ('address', 'city', 'state', 'tags')) + ComplexAddressTuple = namedtuple('ComplexAddress', ('address', 'address_tags', 'city', 'state', 'props')) + ComplexAddressWithOwnersTuple = namedtuple('ComplexAddressWithOwners', ('address', 'address_tags', 'city', + 'state', 'props', 'owners')) + + self.__test_udt(schema, graphson, AddressTuple, AddressWithTagsTuple, + ComplexAddressTuple, ComplexAddressWithOwnersTuple) + + +@requiredse +@GraphTestConfiguration.generate_tests(schema=ClassicGraphSchema) +class ClassicGraphDataTypeTest(GenericGraphDataTypeTest): + pass + + +@requiredse +@GraphTestConfiguration.generate_tests(schema=CoreGraphSchema) +class CoreGraphDataTypeTest(GenericGraphDataTypeTest): + pass diff --git a/tests/integration/advanced/graph/test_graph_query.py b/tests/integration/advanced/graph/test_graph_query.py new file mode 100644 index 0000000000..5bad1b71c5 --- /dev/null +++ b/tests/integration/advanced/graph/test_graph_query.py @@ -0,0 +1,596 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
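+
+# Sketch of the row-factory override exercised in this module
+# (single_object_row_factory hands back the raw JSON string for each row):
+#
+#     prof = session.execution_profile_clone_update(
+#         EXEC_PROFILE_GRAPH_DEFAULT, row_factory=single_object_row_factory)
+#     rs = session.execute_graph("123", execution_profile=prof)
+#     json.loads(rs[0])   # -> {'result': 123}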
+ + +import sys +from packaging.version import Version + +from copy import copy +from itertools import chain +import json +import time + +import unittest + +from cassandra import OperationTimedOut, ConsistencyLevel, InvalidRequest +from cassandra.cluster import EXEC_PROFILE_GRAPH_DEFAULT, NoHostAvailable +from cassandra.protocol import ServerError, SyntaxException +from cassandra.query import QueryTrace +from cassandra.util import Point +from cassandra.graph import (SimpleGraphStatement, single_object_row_factory, + Result, GraphOptions, GraphProtocol, to_bigint) +from cassandra.datastax.graph.query import _graph_options +from cassandra.datastax.graph.types import T + +from tests.integration import DSE_VERSION, requiredse, greaterthanorequaldse68 +from tests.integration.advanced.graph import BasicGraphUnitTestCase, GraphTestConfiguration, \ + validate_classic_vertex, GraphUnitTestCase, validate_classic_edge, validate_path_result_type, \ + validate_line_edge, validate_generic_vertex_result_type, \ + ClassicGraphSchema, CoreGraphSchema, VertexLabel + + +@requiredse +class BasicGraphQueryTest(BasicGraphUnitTestCase): + + def test_consistency_passing(self): + """ + Test to validated that graph consistency levels are properly surfaced to the base driver + + @since 1.0.0 + @jira_ticket PYTHON-509 + @expected_result graph consistency levels are surfaced correctly + @test_category dse graph + """ + cl_attrs = ('graph_read_consistency_level', 'graph_write_consistency_level') + + # Iterates over the graph options and constructs an array containing + # The graph_options that correlate to graph read and write consistency levels + graph_params = [a[2] for a in _graph_options if a[0] in cl_attrs] + + s = self.session + default_profile = s.cluster.profile_manager.profiles[EXEC_PROFILE_GRAPH_DEFAULT] + default_graph_opts = default_profile.graph_options + try: + # Checks the default graph attributes and ensures that both graph_read_consistency_level and graph_write_consistency_level + # Are None by default + for attr in cl_attrs: + self.assertIsNone(getattr(default_graph_opts, attr)) + + res = s.execute_graph("null") + for param in graph_params: + self.assertNotIn(param, res.response_future.message.custom_payload) + + # session defaults are passed + opts = GraphOptions() + opts.update(default_graph_opts) + cl = {0: ConsistencyLevel.ONE, 1: ConsistencyLevel.LOCAL_QUORUM} + for k, v in cl.items(): + setattr(opts, cl_attrs[k], v) + default_profile.graph_options = opts + + res = s.execute_graph("null") + + for k, v in cl.items(): + self.assertEqual(res.response_future.message.custom_payload[graph_params[k]], ConsistencyLevel.value_to_name[v].encode()) + + # passed profile values override session defaults + cl = {0: ConsistencyLevel.ALL, 1: ConsistencyLevel.QUORUM} + opts = GraphOptions() + opts.update(default_graph_opts) + for k, v in cl.items(): + attr_name = cl_attrs[k] + setattr(opts, attr_name, v) + self.assertNotEqual(getattr(default_profile.graph_options, attr_name), getattr(opts, attr_name)) + tmp_profile = s.execution_profile_clone_update(EXEC_PROFILE_GRAPH_DEFAULT, graph_options=opts) + res = s.execute_graph("null", execution_profile=tmp_profile) + + for k, v in cl.items(): + self.assertEqual(res.response_future.message.custom_payload[graph_params[k]], ConsistencyLevel.value_to_name[v].encode()) + finally: + default_profile.graph_options = default_graph_opts + + def test_execute_graph_row_factory(self): + s = self.session + + # default Results + default_profile = 
s.cluster.profile_manager.profiles[EXEC_PROFILE_GRAPH_DEFAULT] + self.assertEqual(default_profile.row_factory, None) # will be resolved to graph_object_row_factory + result = s.execute_graph("123")[0] + self.assertIsInstance(result, Result) + self.assertEqual(result.value, 123) + + # other via parameter + prof = s.execution_profile_clone_update(EXEC_PROFILE_GRAPH_DEFAULT, row_factory=single_object_row_factory) + rs = s.execute_graph("123", execution_profile=prof) + self.assertEqual(rs.response_future.row_factory, single_object_row_factory) + self.assertEqual(json.loads(rs[0]), {'result': 123}) + + def test_execute_graph_timeout(self): + s = self.session + + value = [1, 2, 3] + query = "[%r]" % (value,) + + # default is passed down + default_graph_profile = s.cluster.profile_manager.profiles[EXEC_PROFILE_GRAPH_DEFAULT] + rs = self.session.execute_graph(query) + self.assertEqual(rs[0].value, value) + self.assertEqual(rs.response_future.timeout, default_graph_profile.request_timeout) + + # tiny timeout times out as expected + tmp_profile = copy(default_graph_profile) + tmp_profile.request_timeout = sys.float_info.min + + max_retry_count = 10 + for _ in range(max_retry_count): + start = time.time() + try: + with self.assertRaises(OperationTimedOut): + s.execute_graph(query, execution_profile=tmp_profile) + break + except: + end = time.time() + self.assertAlmostEqual(start, end, 1) + else: + raise Exception("session.execute_graph didn't time out in {0} tries".format(max_retry_count)) + + def test_profile_graph_options(self): + s = self.session + statement = SimpleGraphStatement("true") + ep = self.session.execution_profile_clone_update(EXEC_PROFILE_GRAPH_DEFAULT) + self.assertTrue(s.execute_graph(statement, execution_profile=ep)[0].value) + + # bad graph name to verify it's passed + ep.graph_options = ep.graph_options.copy() + ep.graph_options.graph_name = "definitely_not_correct" + try: + s.execute_graph(statement, execution_profile=ep) + except NoHostAvailable: + self.assertTrue(DSE_VERSION >= Version("6.0")) + except InvalidRequest: + self.assertTrue(DSE_VERSION >= Version("5.0")) + else: + if DSE_VERSION < Version("6.8"): # >6.8 returns true + self.fail("Should have risen ServerError or InvalidRequest") + + def test_additional_custom_payload(self): + s = self.session + custom_payload = {'some': 'example'.encode('utf-8'), 'items': 'here'.encode('utf-8')} + sgs = SimpleGraphStatement("null", custom_payload=custom_payload) + future = s.execute_graph_async(sgs) + + default_profile = s.cluster.profile_manager.profiles[EXEC_PROFILE_GRAPH_DEFAULT] + default_graph_opts = default_profile.graph_options + for k, v in chain(custom_payload.items(), default_graph_opts.get_options_map().items()): + self.assertEqual(future.message.custom_payload[k], v) + + +class GenericGraphQueryTest(GraphUnitTestCase): + + def _test_basic_query(self, schema, graphson): + """ + Test to validate that basic graph query results can be executed with a sane result set. + + Creates a simple classic tinkerpot graph, and attempts to find all vertices + related the vertex marco, that have a label of knows. + See reference graph here + http://www.tinkerpop.com/docs/3.0.0.M1/ + + @since 1.0.0 + @jira_ticket PYTHON-457 + @expected_result graph should find two vertices related to marco via 'knows' edges. 
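+
+        Sketch of the query under test (fixture data comes from
+        schema.fixtures.classic()):
+
+            rs = self.execute_graph(
+                "g.V().has('name','marko').out('knows').values('name')", graphson)
+            self.resultset_to_list(rs)   # contains 'vadas' and 'josh'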
+ + @test_category dse graph + """ + self.execute_graph(schema.fixtures.classic(), graphson) + rs = self.execute_graph('''g.V().has('name','marko').out('knows').values('name')''', graphson) + self.assertFalse(rs.has_more_pages) + results_list = self.resultset_to_list(rs) + self.assertEqual(len(results_list), 2) + self.assertIn('vadas', results_list) + self.assertIn('josh', results_list) + + def _test_geometric_graph_types(self, schema, graphson): + """ + Test to validate that geometric types function correctly + + Creates a very simple graph, and tries to insert a simple point type + + @since 1.0.0 + @jira_ticket DSP-8087 + @expected_result json types associated with insert is parsed correctly + + @test_category dse graph + """ + vertex_label = VertexLabel([('pointP', "Point()")]) + ep = self.get_execution_profile(graphson) + schema.create_vertex_label(self.session, vertex_label, ep) + # import org.apache.cassandra.db.marshal.geometry.Point; + rs = schema.add_vertex(self.session, vertex_label, 'pointP', Point(0, 1), ep) + + # if result set is not parsed correctly this will throw an exception + self.assertIsNotNone(rs) + + def _test_execute_graph_trace(self, schema, graphson): + value = [1, 2, 3] + query = "[%r]" % (value,) + + # default is no trace + rs = self.execute_graph(query, graphson) + results = self.resultset_to_list(rs) + self.assertEqual(results[0], value) + self.assertIsNone(rs.get_query_trace()) + + # request trace + rs = self.execute_graph(query, graphson, trace=True) + results = self.resultset_to_list(rs) + self.assertEqual(results[0], value) + qt = rs.get_query_trace(max_wait_sec=10) + self.assertIsInstance(qt, QueryTrace) + self.assertIsNotNone(qt.duration) + + def _test_range_query(self, schema, graphson): + """ + Test to validate range queries are handled correctly. + + Creates a very large line graph script and executes it. Then proceeds to a range + limited query against it, and ensure that the results are formatted correctly and that + the result set is properly sized. + + @since 1.0.0 + @jira_ticket PYTHON-457 + @expected_result result set should be properly formatted and properly sized + + @test_category dse graph + """ + self.execute_graph(schema.fixtures.line(150), graphson) + rs = self.execute_graph("g.E().range(0,10)", graphson) + self.assertFalse(rs.has_more_pages) + results = self.resultset_to_list(rs) + self.assertEqual(len(results), 10) + ep = self.get_execution_profile(graphson) + for result in results: + schema.ensure_properties(self.session, result, execution_profile=ep) + validate_line_edge(self, result) + + def _test_classic_graph(self, schema, graphson): + """ + Test to validate that basic graph generation, and vertex and edges are surfaced correctly + + Creates a simple classic tinkerpot graph, and iterates over the the vertices and edges + ensureing that each one is correct. 
See reference graph here + http://www.tinkerpop.com/docs/3.0.0.M1/ + + @since 1.0.0 + @jira_ticket PYTHON-457 + @expected_result graph should generate and all vertices and edge results should be + + @test_category dse graph + """ + self.execute_graph(schema.fixtures.classic(), graphson) + rs = self.execute_graph('g.V()', graphson) + ep = self.get_execution_profile(graphson) + for vertex in rs: + schema.ensure_properties(self.session, vertex, execution_profile=ep) + validate_classic_vertex(self, vertex) + rs = self.execute_graph('g.E()', graphson) + for edge in rs: + schema.ensure_properties(self.session, edge, execution_profile=ep) + validate_classic_edge(self, edge) + + def _test_graph_classic_path(self, schema, graphson): + """ + Test to validate that the path version of the result type is generated correctly. It also + tests basic path results as that is not covered elsewhere + + @since 1.0.0 + @jira_ticket PYTHON-479 + @expected_result path object should be unpacked correctly including all nested edges and verticies + @test_category dse graph + """ + self.execute_graph(schema.fixtures.classic(), graphson) + rs = self.execute_graph("g.V().hasLabel('person').has('name', 'marko').as('a').outE('knows').inV().as('c', 'd')." + " outE('created').as('e', 'f', 'g').inV().path()", + graphson) + rs_list = list(rs) + self.assertEqual(len(rs_list), 2) + for result in rs_list: + try: + path = result.as_path() + except: + path = result + + ep = self.get_execution_profile(graphson) + for obj in path.objects: + schema.ensure_properties(self.session, obj, ep) + + validate_path_result_type(self, path) + + def _test_large_create_script(self, schema, graphson): + """ + Test to validate that server errors due to large groovy scripts are properly surfaced + + Creates a very large line graph script and executes it. Then proceeds to create a line graph script + that is to large for the server to handle expects a server error to be returned + + @since 1.0.0 + @jira_ticket PYTHON-457 + @expected_result graph should generate and all vertices and edge results should be + + @test_category dse graph + """ + self.execute_graph(schema.fixtures.line(150), graphson) + self.execute_graph(schema.fixtures.line(300), graphson) # This should pass since the queries are split + self.assertRaises(SyntaxException, self.execute_graph, schema.fixtures.line(300, single_script=True), graphson) # this is not and too big + + def _test_large_result_set(self, schema, graphson): + """ + Test to validate that large result sets return correctly. + + Creates a very large graph. Ensures that large result sets are handled appropriately. 
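+
+        Sketch of the pattern used below (the larger request_timeout is passed
+        through the test helper as execution_profile_options):
+
+            self.execute_graph(schema.fixtures.large(), graphson,
+                               execution_profile_options={'request_timeout': 32})
+            for result in self.execute_graph("g.V()", graphson):
+                validate_generic_vertex_result_type(self, result)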
+ + @since 1.0.0 + @jira_ticket PYTHON-457 + @expected_result when limits of result sets are hit errors should be surfaced appropriately + + @test_category dse graph + """ + self.execute_graph(schema.fixtures.large(), graphson, execution_profile_options={'request_timeout': 32}) + rs = self.execute_graph("g.V()", graphson) + for result in rs: + validate_generic_vertex_result_type(self, result) + + def _test_param_passing(self, schema, graphson): + """ + Test to validate that parameter passing works as expected + + @since 1.0.0 + @jira_ticket PYTHON-457 + @expected_result parameters work as expected + + @test_category dse graph + """ + + # unused parameters are passed, but ignored + self.execute_graph("null", graphson, params={"doesn't": "matter", "what's": "passed"}) + + # multiple params + rs = self.execute_graph("[a, b]", graphson, params={'a': 0, 'b': 1}) + results = self.resultset_to_list(rs) + self.assertEqual(results[0], 0) + self.assertEqual(results[1], 1) + + if graphson == GraphProtocol.GRAPHSON_1_0: + # different value types + for param in (None, "string", 1234, 5.678, True, False): + result = self.resultset_to_list(self.execute_graph('x', graphson, params={'x': param}))[0] + self.assertEqual(result, param) + + def _test_vertex_property_properties(self, schema, graphson): + """ + Test verifying vertex property properties + + @since 1.0.0 + @jira_ticket PYTHON-487 + + @test_category dse graph + """ + if schema is not ClassicGraphSchema: + raise unittest.SkipTest('skipped because rich properties are only supported with classic graphs') + + self.execute_graph("schema.propertyKey('k0').Text().ifNotExists().create();", graphson) + self.execute_graph("schema.propertyKey('k1').Text().ifNotExists().create();", graphson) + self.execute_graph("schema.propertyKey('key').Text().properties('k0', 'k1').ifNotExists().create();", graphson) + self.execute_graph("schema.vertexLabel('MLP').properties('key').ifNotExists().create();", graphson) + v = self.execute_graph('''v = graph.addVertex('MLP') + v.property('key', 'value', 'k0', 'v0', 'k1', 'v1') + v''', graphson)[0] + self.assertEqual(len(v.properties), 1) + self.assertEqual(len(v.properties['key']), 1) + p = v.properties['key'][0] + self.assertEqual(p.label, 'key') + self.assertEqual(p.value, 'value') + self.assertEqual(p.properties, {'k0': 'v0', 'k1': 'v1'}) + + def _test_vertex_multiple_properties(self, schema, graphson): + """ + Test verifying vertex property form for various Cardinality + + All key types are encoded as a list, regardless of cardinality + + Single cardinality properties have only one value -- the last one added + + Default is single (this is config dependent) + + @since 1.0.0 + @jira_ticket PYTHON-487 + + @test_category dse graph + """ + if schema is not ClassicGraphSchema: + raise unittest.SkipTest('skipped because multiple properties are only supported with classic graphs') + + self.execute_graph('''Schema schema = graph.schema(); + schema.propertyKey('mult_key').Text().multiple().ifNotExists().create(); + schema.propertyKey('single_key').Text().single().ifNotExists().create(); + schema.vertexLabel('MPW1').properties('mult_key').ifNotExists().create(); + schema.vertexLabel('SW1').properties('single_key').ifNotExists().create();''', graphson) + + v = self.execute_graph('''v = graph.addVertex('MPW1') + v.property('mult_key', 'value') + v''', graphson)[0] + self.assertEqual(len(v.properties), 1) + self.assertEqual(len(v.properties['mult_key']), 1) + self.assertEqual(v.properties['mult_key'][0].label, 'mult_key') + 
self.assertEqual(v.properties['mult_key'][0].value, 'value') + + # multiple_with_two_values + v = self.execute_graph('''g.addV('MPW1').property('mult_key', 'value0').property('mult_key', 'value1')''', graphson)[0] + self.assertEqual(len(v.properties), 1) + self.assertEqual(len(v.properties['mult_key']), 2) + self.assertEqual(v.properties['mult_key'][0].label, 'mult_key') + self.assertEqual(v.properties['mult_key'][1].label, 'mult_key') + self.assertEqual(v.properties['mult_key'][0].value, 'value0') + self.assertEqual(v.properties['mult_key'][1].value, 'value1') + + # single_with_one_value + v = self.execute_graph('''v = graph.addVertex('SW1') + v.property('single_key', 'value') + v''', graphson)[0] + self.assertEqual(len(v.properties), 1) + self.assertEqual(len(v.properties['single_key']), 1) + self.assertEqual(v.properties['single_key'][0].label, 'single_key') + self.assertEqual(v.properties['single_key'][0].value, 'value') + + if DSE_VERSION < Version('6.8'): + # single_with_two_values + with self.assertRaises(InvalidRequest): + v = self.execute_graph(''' + v = graph.addVertex('SW1') + v.property('single_key', 'value0').property('single_key', 'value1').next() + v + ''', graphson)[0] + else: + # >=6.8 single_with_two_values, first one wins + v = self.execute_graph('''v = graph.addVertex('SW1') + v.property('single_key', 'value0').property('single_key', 'value1') + v''', graphson)[0] + self.assertEqual(v.properties['single_key'][0].value, 'value0') + + def _test_result_forms(self, schema, graphson): + """ + Test to validate that geometric types function correctly + + Creates a very simple graph, and tries to insert a simple point type + + @since 1.0.0 + @jira_ticket DSP-8087 + @expected_result json types associated with insert is parsed correctly + + @test_category dse graph + """ + self.execute_graph(schema.fixtures.classic(), graphson) + ep = self.get_execution_profile(graphson) + + results = self.resultset_to_list(self.session.execute_graph('g.V()', execution_profile=ep)) + self.assertGreater(len(results), 0, "Result set was empty this was not expected") + for v in results: + schema.ensure_properties(self.session, v, ep) + validate_classic_vertex(self, v) + + results = self.resultset_to_list(self.session.execute_graph('g.E()', execution_profile=ep)) + self.assertGreater(len(results), 0, "Result set was empty this was not expected") + for e in results: + schema.ensure_properties(self.session, e, ep) + validate_classic_edge(self, e) + + def _test_query_profile(self, schema, graphson): + """ + Test to validate profiling results are deserialized properly. + + @since 1.6.0 + @jira_ticket PYTHON-1057 + @expected_result TraversalMetrics and Metrics are deserialized properly + + @test_category dse graph + """ + if graphson == GraphProtocol.GRAPHSON_1_0: + raise unittest.SkipTest('skipped because there is no metrics deserializer with graphson1') + + ep = self.get_execution_profile(graphson) + results = list(self.session.execute_graph("g.V().profile()", execution_profile=ep)) + self.assertEqual(len(results), 1) + self.assertIn('metrics', results[0]) + self.assertIn('dur', results[0]) + self.assertEqual(len(results[0]['metrics']), 2) + self.assertIn('dur', results[0]['metrics'][0]) + + def _test_query_bulkset(self, schema, graphson): + """ + Test to validate bulkset results are deserialized properly. 
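+
+        Sketch of the aggregation under test (with GraphSON2/3 the server-side
+        BulkSet is deserialized to a plain list, here a list of ages):
+
+            results = list(self.session.execute_graph(
+                'g.V().hasLabel("person").aggregate("x").by("age").cap("x")',
+                execution_profile=ep))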
+ + @since 1.6.0 + @jira_ticket PYTHON-1060 + @expected_result BulkSet is deserialized properly to a list + + @test_category dse graph + """ + self.execute_graph(schema.fixtures.classic(), graphson) + ep = self.get_execution_profile(graphson) + results = list(self.session.execute_graph( + 'g.V().hasLabel("person").aggregate("x").by("age").cap("x")', + execution_profile=ep)) + self.assertEqual(len(results), 1) + results = results[0] + if type(results) is Result: + results = results.value + else: + self.assertEqual(len(results), 5) + self.assertEqual(results.count(35), 2) + + @greaterthanorequaldse68 + def _test_elementMap_query(self, schema, graphson): + """ + Test to validate that an elementMap can be serialized properly. + """ + self.execute_graph(schema.fixtures.classic(), graphson) + rs = self.execute_graph('''g.V().has('name','marko').elementMap()''', graphson) + results_list = self.resultset_to_list(rs) + self.assertEqual(len(results_list), 1) + row = results_list[0] + if graphson == GraphProtocol.GRAPHSON_3_0: + self.assertIn(T.id, row) + self.assertIn(T.label, row) + if schema is CoreGraphSchema: + self.assertEqual(row[T.id], 'dseg:/person/marko') + self.assertEqual(row[T.label], 'person') + else: + self.assertIn('id', row) + self.assertIn('label', row) + + +@GraphTestConfiguration.generate_tests(schema=ClassicGraphSchema) +class ClassicGraphQueryTest(GenericGraphQueryTest): + pass + + +@GraphTestConfiguration.generate_tests(schema=CoreGraphSchema) +class CoreGraphQueryTest(GenericGraphQueryTest): + pass + + +@GraphTestConfiguration.generate_tests(schema=CoreGraphSchema) +class CoreGraphQueryWithTypeWrapperTest(GraphUnitTestCase): + + def _test_basic_query_with_type_wrapper(self, schema, graphson): + """ + Test to validate that a query using a type wrapper works. + + @since 2.8.0 + @jira_ticket PYTHON-1051 + @expected_result graph query works and doesn't raise an exception + + @test_category dse graph + """ + ep = self.get_execution_profile(graphson) + vl = VertexLabel(['tupleOf(Int, Bigint)']) + schema.create_vertex_label(self.session, vl, execution_profile=ep) + + prop_name = next(iter(vl.non_pk_properties.keys())) + with self.assertRaises(InvalidRequest): + schema.add_vertex(self.session, vl, prop_name, (1, 42), execution_profile=ep) + + schema.add_vertex(self.session, vl, prop_name, (1, to_bigint(42)), execution_profile=ep) diff --git a/tests/integration/advanced/test_adv_metadata.py b/tests/integration/advanced/test_adv_metadata.py new file mode 100644 index 0000000000..80309302f0 --- /dev/null +++ b/tests/integration/advanced/test_adv_metadata.py @@ -0,0 +1,394 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
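+
+# Sketch of the schema-metadata round trip these tests rely on (functions are
+# keyed by their "<name>(<arg types>)" signature in the keyspace metadata;
+# my_func is a placeholder name):
+#
+#     cluster.refresh_schema_metadata()
+#     fn = cluster.metadata.keyspaces[ks_name].functions['my_func(int,int)']
+#     fn.deterministic, fn.monotonic, fn.monotonic_on
+#     session.execute(fn.as_cql_query())   # generated CQL round-trips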
+ +from packaging.version import Version + +from tests.integration import (BasicExistingKeyspaceUnitTestCase, BasicSharedKeyspaceUnitTestCase, + BasicSharedKeyspaceUnitTestCaseRF1, + greaterthanorequaldse51, greaterthanorequaldse60, + greaterthanorequaldse68, use_single_node, + DSE_VERSION, requiredse, TestCluster) + +import unittest + +import logging +import time + + +log = logging.getLogger(__name__) + + +def setup_module(): + if DSE_VERSION: + use_single_node() + + +@requiredse +@greaterthanorequaldse60 +class FunctionAndAggregateMetadataTests(BasicSharedKeyspaceUnitTestCaseRF1): + + @classmethod + def setUpClass(cls): + if DSE_VERSION: + super(FunctionAndAggregateMetadataTests, cls).setUpClass() + + @classmethod + def tearDownClass(cls): + if DSE_VERSION: + super(FunctionAndAggregateMetadataTests, cls).setUpClass() + + def setUp(self): + self.func_name = self.function_table_name + '_func' + self.agg_name = self.function_table_name + '_agg(int)' + + def _populated_ks_meta_attr(self, attr_name): + val, start_time = None, time.time() + while not val: + self.cluster.refresh_schema_metadata() + val = getattr(self.cluster.metadata.keyspaces[self.keyspace_name], + attr_name) + self.assertLess(time.time(), start_time + 30, + 'did not see func in metadata in 30s') + log.debug('done blocking; dict is populated: {}'.format(val)) + return val + + def test_monotonic_on_and_deterministic_function(self): + self.session.execute(""" + CREATE FUNCTION {ksn}.{ftn}(key int, val int) + RETURNS NULL ON NULL INPUT + RETURNS int + DETERMINISTIC + MONOTONIC ON val + LANGUAGE java AS 'return key+val;'; + """.format(ksn=self.keyspace_name, + ftn=self.func_name)) + fn = self._populated_ks_meta_attr('functions')[ + '{}(int,int)'.format(self.func_name) + ] + self.assertEqual(fn.monotonic_on, ['val']) + # monotonic is not set by MONOTONIC ON + self.assertFalse(fn.monotonic) + self.assertTrue(fn.deterministic) + self.assertEqual('CREATE FUNCTION {ksn}.{ftn}(key int, val int) ' + 'RETURNS NULL ON NULL INPUT ' + 'RETURNS int DETERMINISTIC MONOTONIC ON val ' + 'LANGUAGE java AS $$return key+val;$$' + ''.format(ksn=self.keyspace_name, + ftn=self.func_name), + fn.as_cql_query()) + self.session.execute('DROP FUNCTION {}.{}'.format(self.keyspace_name, + self.func_name)) + self.session.execute(fn.as_cql_query()) + + def test_monotonic_all_and_nondeterministic_function(self): + self.session.execute(""" + CREATE FUNCTION {ksn}.{ftn}(key int, val int) + RETURNS NULL ON NULL INPUT + RETURNS int + MONOTONIC + LANGUAGE java AS 'return key+val;'; + """.format(ksn=self.keyspace_name, + ftn=self.func_name)) + fn = self._populated_ks_meta_attr('functions')[ + '{}(int,int)'.format(self.func_name) + ] + self.assertEqual(set(fn.monotonic_on), {'key', 'val'}) + self.assertTrue(fn.monotonic) + self.assertFalse(fn.deterministic) + self.assertEqual('CREATE FUNCTION {ksn}.{ftn}(key int, val int) ' + 'RETURNS NULL ON NULL INPUT RETURNS int MONOTONIC ' + 'LANGUAGE java AS $$return key+val;$$' + ''.format(ksn=self.keyspace_name, + ftn=self.func_name), + fn.as_cql_query()) + self.session.execute('DROP FUNCTION {}.{}'.format(self.keyspace_name, + self.func_name)) + self.session.execute(fn.as_cql_query()) + + def _create_func_for_aggregate(self): + self.session.execute(""" + CREATE FUNCTION {ksn}.{ftn}(key int, val int) + RETURNS NULL ON NULL INPUT + RETURNS int + DETERMINISTIC + LANGUAGE java AS 'return key+val;'; + """.format(ksn=self.keyspace_name, + ftn=self.func_name)) + + def test_deterministic_aggregate(self): + 
self._create_func_for_aggregate()
+        self.session.execute("""
+            CREATE AGGREGATE {ksn}.{an}
+            SFUNC {ftn}
+            STYPE int
+            INITCOND 0
+            DETERMINISTIC
+        """.format(ksn=self.keyspace_name,
+                   ftn=self.func_name,
+                   an=self.agg_name))
+        ag = self._populated_ks_meta_attr('aggregates')[self.agg_name]
+        self.assertTrue(ag.deterministic)
+        self.assertEqual(
+            'CREATE AGGREGATE {ksn}.{an} SFUNC '
+            '{ftn} STYPE int INITCOND 0 DETERMINISTIC'
+            ''.format(ksn=self.keyspace_name,
+                      ftn=self.func_name,
+                      an=self.agg_name),
+            ag.as_cql_query())
+        self.session.execute('DROP AGGREGATE {}.{}'.format(self.keyspace_name,
+                                                           self.agg_name))
+        self.session.execute(ag.as_cql_query())
+
+    def test_nondeterministic_aggregate(self):
+        self._create_func_for_aggregate()
+        self.session.execute("""
+            CREATE AGGREGATE {ksn}.{an}
+            SFUNC {ftn}
+            STYPE int
+            INITCOND 0
+        """.format(ksn=self.keyspace_name,
+                   ftn=self.func_name,
+                   an=self.agg_name))
+        ag = self._populated_ks_meta_attr('aggregates')[self.agg_name]
+        self.assertFalse(ag.deterministic)
+        self.assertEqual(
+            'CREATE AGGREGATE {ksn}.{an} SFUNC '
+            '{ftn} STYPE int INITCOND 0'
+            ''.format(ksn=self.keyspace_name,
+                      ftn=self.func_name,
+                      an=self.agg_name),
+            ag.as_cql_query())
+        self.session.execute('DROP AGGREGATE {}.{}'.format(self.keyspace_name,
+                                                           self.agg_name))
+        self.session.execute(ag.as_cql_query())
+
+
+@requiredse
+class RLACMetadataTests(BasicSharedKeyspaceUnitTestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        if DSE_VERSION:
+            super(RLACMetadataTests, cls).setUpClass()
+
+    @classmethod
+    def tearDownClass(cls):
+        if DSE_VERSION:
+            super(RLACMetadataTests, cls).tearDownClass()
+
+    @greaterthanorequaldse51
+    def test_rlac_on_table(self):
+        """
+        Checks to ensure that the RLAC table extension appends the proper CQL when a table is exported
+
+        @since 3.20
+        @jira_ticket PYTHON-638
+        @expected_result the RESTRICT ROWS clause should appear in the exported table CQL
+
+        @test_category metadata
+        """
+        self.session.execute("CREATE TABLE {0}.reports ("
+                             " report_user text, "
+                             " report_number int, "
+                             " report_month int, "
+                             " report_year int, "
+                             " report_text text,"
+                             " PRIMARY KEY (report_user, report_number))".format(self.keyspace_name))
+        restrict_cql = "RESTRICT ROWS ON {0}.reports USING report_user".format(self.keyspace_name)
+        self.session.execute(restrict_cql)
+        table_meta = self.cluster.metadata.keyspaces[self.keyspace_name].tables['reports']
+        self.assertTrue(restrict_cql in table_meta.export_as_string())
+
+    @unittest.skip("DSE 5.1 doesn't support MV with RLAC; remove after update")
+    @greaterthanorequaldse51
+    def test_rlac_on_mv(self):
+        """
+        Checks to ensure that the RLAC table extension appends the proper CQL when a materialized view is exported
+
+        @since 3.20
+        @jira_ticket PYTHON-682
+        @expected_result the RESTRICT ROWS clauses should appear in the exported CQL for both the base table and the view
+
+        @test_category metadata
+        """
+        self.session.execute("CREATE TABLE {0}.reports2 ("
+                             " report_user text, "
+                             " report_number int, "
+                             " report_month int, "
+                             " report_year int, "
+                             " report_text text,"
+                             " PRIMARY KEY (report_user, report_number))".format(self.keyspace_name))
+        self.session.execute("CREATE MATERIALIZED VIEW {0}.reports_by_year AS "
+                             " SELECT report_year, report_user, report_number, report_text FROM {0}.reports2 "
+                             " WHERE report_user IS NOT NULL AND report_number IS NOT NULL AND report_year IS NOT NULL "
+                             " PRIMARY KEY ((report_year, report_user), report_number)".format(self.keyspace_name))
+
+        restrict_cql_table = "RESTRICT ROWS ON {0}.reports2 USING report_user".format(self.keyspace_name)
+        
self.session.execute(restrict_cql_table)
+        restrict_cql_view = "RESTRICT ROWS ON {0}.reports_by_year USING report_user".format(self.keyspace_name)
+        self.session.execute(restrict_cql_view)
+        table_cql = self.cluster.metadata.keyspaces[self.keyspace_name].tables['reports2'].export_as_string()
+        view_cql = self.cluster.metadata.keyspaces[self.keyspace_name].tables['reports2'].views["reports_by_year"].export_as_string()
+        self.assertTrue(restrict_cql_table in table_cql)
+        self.assertTrue(restrict_cql_view in table_cql)
+        self.assertTrue(restrict_cql_view in view_cql)
+        self.assertTrue(restrict_cql_table not in view_cql)
+
+
+@requiredse
+class NodeSyncMetadataTests(BasicSharedKeyspaceUnitTestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        if DSE_VERSION:
+            super(NodeSyncMetadataTests, cls).setUpClass()
+
+    @classmethod
+    def tearDownClass(cls):
+        if DSE_VERSION:
+            super(NodeSyncMetadataTests, cls).tearDownClass()
+
+    @greaterthanorequaldse60
+    def test_nodesync_on_table(self):
+        """
+        Checks to ensure that nodesync is visible through driver metadata
+
+        @since 3.20
+        @jira_ticket PYTHON-799
+        @expected_result nodesync should be enabled
+
+        @test_category metadata
+        """
+        self.session.execute("CREATE TABLE {0}.reports ("
+                             " report_user text PRIMARY KEY"
+                             ") WITH nodesync = {{"
+                             "'enabled': 'true', 'deadline_target_sec' : 86400 }};".format(
+            self.keyspace_name
+        ))
+        table_meta = self.cluster.metadata.keyspaces[self.keyspace_name].tables['reports']
+        self.assertIn('nodesync =', table_meta.export_as_string())
+        self.assertIn('nodesync', table_meta.options)
+
+
+@greaterthanorequaldse68
+class GraphMetadataTests(BasicExistingKeyspaceUnitTestCase):
+    """
+    Various tests to ensure that graph metadata are visible through driver metadata
+    @since DSE6.8
+    @jira_ticket PYTHON-996
+    @expected_result graph metadata are fetched
+    @test_category metadata
+    """
+
+    @classmethod
+    def setUpClass(cls):
+        if DSE_VERSION and DSE_VERSION >= Version('6.8'):
+            super(GraphMetadataTests, cls).setUpClass()
+            cls.session.execute("""
+            CREATE KEYSPACE ks_no_graph_engine WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1};
+            """)
+            cls.session.execute("""
+            CREATE KEYSPACE %s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1} and graph_engine = 'Core';
+            """ % (cls.ks_name,))
+
+            cls.session.execute("""
+            CREATE TABLE %s.person (name text PRIMARY KEY) WITH VERTEX LABEL;
+            """ % (cls.ks_name,))
+
+            cls.session.execute("""
+            CREATE TABLE %s.software(company text, name text, version int, PRIMARY KEY((company, name), version)) WITH VERTEX LABEL rocksolidsoftware;
+            """ % (cls.ks_name,))
+
+            cls.session.execute("""
+            CREATE TABLE %s.contributors (contributor text, company_name text, software_name text, software_version int,
+                PRIMARY KEY (contributor, company_name, software_name, software_version) )
+                WITH CLUSTERING ORDER BY (company_name ASC, software_name ASC, software_version ASC)
+                AND EDGE LABEL contrib FROM person(contributor) TO rocksolidsoftware((company_name, software_name), software_version);
+            """ % (cls.ks_name,))
+
+    @classmethod
+    def tearDownClass(cls):
+        if DSE_VERSION and DSE_VERSION >= Version('6.8'):
+            cls.session.execute('DROP KEYSPACE {0}'.format('ks_no_graph_engine'))
+            cls.session.execute('DROP KEYSPACE {0}'.format(cls.ks_name))
+            cls.cluster.shutdown()
+
+    def test_keyspace_metadata(self):
+        self.assertIsNone(self.cluster.metadata.keyspaces['ks_no_graph_engine'].graph_engine)
+        
self.assertEqual(self.cluster.metadata.keyspaces[self.ks_name].graph_engine, 'Core') + + def test_keyspace_metadata_alter_graph_engine(self): + self.session.execute("ALTER KEYSPACE %s WITH graph_engine = 'Tinker'" % (self.ks_name,)) + self.assertEqual(self.cluster.metadata.keyspaces[self.ks_name].graph_engine, 'Tinker') + self.session.execute("ALTER KEYSPACE %s WITH graph_engine = 'Core'" % (self.ks_name,)) + self.assertEqual(self.cluster.metadata.keyspaces[self.ks_name].graph_engine, 'Core') + + def test_vertex_metadata(self): + vertex_meta = self.cluster.metadata.keyspaces[self.ks_name].tables['person'].vertex + self.assertEqual(vertex_meta.keyspace_name, self.ks_name) + self.assertEqual(vertex_meta.table_name, 'person') + self.assertEqual(vertex_meta.label_name, 'person') + + vertex_meta = self.cluster.metadata.keyspaces[self.ks_name].tables['software'].vertex + self.assertEqual(vertex_meta.keyspace_name, self.ks_name) + self.assertEqual(vertex_meta.table_name, 'software') + self.assertEqual(vertex_meta.label_name, 'rocksolidsoftware') + + def test_edge_metadata(self): + edge_meta = self.cluster.metadata.keyspaces[self.ks_name].tables['contributors'].edge + self.assertEqual(edge_meta.keyspace_name, self.ks_name) + self.assertEqual(edge_meta.table_name, 'contributors') + self.assertEqual(edge_meta.label_name, 'contrib') + self.assertEqual(edge_meta.from_table, 'person') + self.assertEqual(edge_meta.from_label, 'person') + self.assertEqual(edge_meta.from_partition_key_columns, ['contributor']) + self.assertEqual(edge_meta.from_clustering_columns, []) + self.assertEqual(edge_meta.to_table, 'software') + self.assertEqual(edge_meta.to_label, 'rocksolidsoftware') + self.assertEqual(edge_meta.to_partition_key_columns, ['company_name', 'software_name']) + self.assertEqual(edge_meta.to_clustering_columns, ['software_version']) + + +@greaterthanorequaldse68 +class GraphMetadataSchemaErrorTests(BasicExistingKeyspaceUnitTestCase): + """ + Test that we can connect when the graph schema is broken. + """ + + def test_connection_on_graph_schema_error(self): + self.session = self.cluster.connect() + + self.session.execute(""" + CREATE KEYSPACE %s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1} and graph_engine = 'Core'; + """ % (self.ks_name,)) + + self.session.execute(""" + CREATE TABLE %s.person (name text PRIMARY KEY) WITH VERTEX LABEL; + """ % (self.ks_name,)) + + self.session.execute(""" + CREATE TABLE %s.software(company text, name text, version int, PRIMARY KEY((company, name), version)) WITH VERTEX LABEL rocksolidsoftware; + """ % (self.ks_name,)) + + self.session.execute(""" + CREATE TABLE %s.contributors (contributor text, company_name text, software_name text, software_version int, + PRIMARY KEY (contributor, company_name, software_name, software_version) ) + WITH CLUSTERING ORDER BY (company_name ASC, software_name ASC, software_version ASC) + AND EDGE LABEL contrib FROM person(contributor) TO rocksolidsoftware((company_name, software_name), software_version); + """ % (self.ks_name,)) + + self.session.execute('TRUNCATE system_schema.vertices') + TestCluster().connect().shutdown() diff --git a/tests/integration/advanced/test_auth.py b/tests/integration/advanced/test_auth.py new file mode 100644 index 0000000000..df1f385f74 --- /dev/null +++ b/tests/integration/advanced/test_auth.py @@ -0,0 +1,534 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest +import logging +import os +import subprocess +import time + +from ccmlib.dse_cluster import DseCluster +from nose.plugins.attrib import attr +from packaging.version import Version + +from cassandra.auth import (DSEGSSAPIAuthProvider, DSEPlainTextAuthProvider, + SaslAuthProvider, TransitionalModePlainTextAuthProvider) +from cassandra.cluster import EXEC_PROFILE_GRAPH_DEFAULT, NoHostAvailable +from cassandra.protocol import Unauthorized +from cassandra.query import SimpleStatement +from tests.integration import (get_cluster, greaterthanorequaldse51, + remove_cluster, requiredse, DSE_VERSION, TestCluster) +from tests.integration.advanced import ADS_HOME, use_single_node_with_graph +from tests.integration.advanced.graph import reset_graph, ClassicGraphFixtures + + +log = logging.getLogger(__name__) + + +def setup_module(): + if DSE_VERSION: + use_single_node_with_graph() + + +def teardown_module(): + if DSE_VERSION: + remove_cluster() # this test messes with config + + +def wait_role_manager_setup_then_execute(session, statements): + for s in statements: + exc = None + for attempt in range(3): + try: + session.execute(s) + break + except Exception as e: + exc = e + time.sleep(5) + else: # if we didn't reach `break` + if exc is not None: + raise exc + + +@attr('long') +@requiredse +class BasicDseAuthTest(unittest.TestCase): + + @classmethod + def setUpClass(self): + """ + This will setup the necessary infrastructure to run our authentication tests. It requires the ADS_HOME environment variable + and our custom embedded apache directory server jar in order to run. 
+ """ + if not DSE_VERSION: + return + + clear_kerberos_tickets() + self.cluster = None + + # Setup variables for various keytab and other files + self.conf_file_dir = os.path.join(ADS_HOME, "conf/") + self.krb_conf = os.path.join(self.conf_file_dir, "krb5.conf") + self.dse_keytab = os.path.join(self.conf_file_dir, "dse.keytab") + self.dseuser_keytab = os.path.join(self.conf_file_dir, "dseuser.keytab") + self.cassandra_keytab = os.path.join(self.conf_file_dir, "cassandra.keytab") + self.bob_keytab = os.path.join(self.conf_file_dir, "bob.keytab") + self.charlie_keytab = os.path.join(self.conf_file_dir, "charlie.keytab") + actual_jar = os.path.join(ADS_HOME, "embedded-ads.jar") + + # Create configuration directories if they don't already exist + if not os.path.exists(self.conf_file_dir): + os.makedirs(self.conf_file_dir) + if not os.path.exists(actual_jar): + raise RuntimeError('could not find {}'.format(actual_jar)) + log.warning("Starting adserver") + # Start the ADS, this will create the keytab con configuration files listed above + self.proc = subprocess.Popen(['java', '-jar', actual_jar, '-k', '--confdir', self.conf_file_dir], shell=False) + time.sleep(10) + # TODO poll for server to come up + + log.warning("Starting adserver started") + ccm_cluster = get_cluster() + log.warning("fetching tickets") + # Stop cluster if running and configure it with the correct options + ccm_cluster.stop() + if isinstance(ccm_cluster, DseCluster): + # Setup kerberos options in cassandra.yaml + config_options = {'kerberos_options': {'keytab': self.dse_keytab, + 'service_principal': 'dse/_HOST@DATASTAX.COM', + 'qop': 'auth'}, + 'authentication_options': {'enabled': 'true', + 'default_scheme': 'kerberos', + 'scheme_permissions': 'true', + 'allow_digest_with_kerberos': 'true', + 'plain_text_without_ssl': 'warn', + 'transitional_mode': 'disabled'}, + 'authorization_options': {'enabled': 'true'}} + + krb5java = "-Djava.security.krb5.conf=" + self.krb_conf + # Setup dse authenticator in cassandra.yaml + ccm_cluster.set_configuration_options({ + 'authenticator': 'com.datastax.bdp.cassandra.auth.DseAuthenticator', + 'authorizer': 'com.datastax.bdp.cassandra.auth.DseAuthorizer' + }) + ccm_cluster.set_dse_configuration_options(config_options) + ccm_cluster.start(wait_for_binary_proto=True, wait_other_notice=True, jvm_args=[krb5java]) + else: + log.error("Cluster is not dse cluster test will fail") + + @classmethod + def tearDownClass(self): + """ + Terminates running ADS (Apache directory server). + """ + if not DSE_VERSION: + return + + self.proc.terminate() + + def tearDown(self): + """ + This will clear any existing kerberos tickets by using kdestroy + """ + clear_kerberos_tickets() + if self.cluster: + self.cluster.shutdown() + + def refresh_kerberos_tickets(self, keytab_file, user_name, krb_conf): + """ + Fetches a new ticket for using the keytab file and username provided. + """ + self.ads_pid = subprocess.call(['kinit', '-t', keytab_file, user_name], env={'KRB5_CONFIG': krb_conf}, shell=False) + + def connect_and_query(self, auth_provider, query=None): + """ + Runs a simple system query with the auth_provided specified. 
+ """ + os.environ['KRB5_CONFIG'] = self.krb_conf + self.cluster = TestCluster(auth_provider=auth_provider) + self.session = self.cluster.connect() + query = query if query else "SELECT * FROM system.local" + statement = SimpleStatement(query) + rs = self.session.execute(statement) + return rs + + def test_should_not_authenticate_with_bad_user_ticket(self): + """ + This tests will attempt to authenticate with a user that has a valid ticket, but is not a valid dse user. + @since 3.20 + @jira_ticket PYTHON-457 + @test_category dse auth + @expected_result NoHostAvailable exception should be thrown + + """ + self.refresh_kerberos_tickets(self.dseuser_keytab, "dseuser@DATASTAX.COM", self.krb_conf) + auth_provider = DSEGSSAPIAuthProvider(service='dse', qops=["auth"]) + self.assertRaises(NoHostAvailable, self.connect_and_query, auth_provider) + + def test_should_not_authenticate_without_ticket(self): + """ + This tests will attempt to authenticate with a user that is valid but has no ticket + @since 3.20 + @jira_ticket PYTHON-457 + @test_category dse auth + @expected_result NoHostAvailable exception should be thrown + + """ + auth_provider = DSEGSSAPIAuthProvider(service='dse', qops=["auth"]) + self.assertRaises(NoHostAvailable, self.connect_and_query, auth_provider) + + def test_connect_with_kerberos(self): + """ + This tests will attempt to authenticate with a user that is valid and has a ticket + @since 3.20 + @jira_ticket PYTHON-457 + @test_category dse auth + @expected_result Client should be able to connect and run a basic query + + """ + self.refresh_kerberos_tickets(self.cassandra_keytab, "cassandra@DATASTAX.COM", self.krb_conf) + auth_provider = DSEGSSAPIAuthProvider() + rs = self.connect_and_query(auth_provider) + self.assertIsNotNone(rs) + connections = [c for holders in self.cluster.get_connection_holders() for c in holders.get_connections()] + # Check to make sure our server_authenticator class is being set appropriate + for connection in connections: + self.assertTrue('DseAuthenticator' in connection.authenticator.server_authenticator_class) + + def test_connect_with_kerberos_and_graph(self): + """ + This tests will attempt to authenticate with a user and execute a graph query + @since 3.20 + @jira_ticket PYTHON-457 + @test_category dse auth + @expected_result Client should be able to connect and run a basic graph query with authentication + + """ + self.refresh_kerberos_tickets(self.cassandra_keytab, "cassandra@DATASTAX.COM", self.krb_conf) + + auth_provider = DSEGSSAPIAuthProvider(service='dse', qops=["auth"]) + rs = self.connect_and_query(auth_provider) + self.assertIsNotNone(rs) + reset_graph(self.session, self._testMethodName.lower()) + profiles = self.cluster.profile_manager.profiles + profiles[EXEC_PROFILE_GRAPH_DEFAULT].graph_options.graph_name = self._testMethodName.lower() + self.session.execute_graph(ClassicGraphFixtures.classic()) + + rs = self.session.execute_graph('g.V()') + self.assertIsNotNone(rs) + + def test_connect_with_kerberos_host_not_resolved(self): + """ + This tests will attempt to authenticate with IP, this will fail on osx. + The success or failure of this test is dependent on a reverse dns lookup which can be impacted by your environment + if it fails don't panic. 
+ @since 3.20 + @jira_ticket PYTHON-566 + @test_category dse auth + @expected_result Client should error when ip is used + + """ + self.refresh_kerberos_tickets(self.cassandra_keytab, "cassandra@DATASTAX.COM", self.krb_conf) + DSEGSSAPIAuthProvider(service='dse', qops=["auth"], resolve_host_name=False) + + def test_connect_with_explicit_principal(self): + """ + This tests will attempt to authenticate using valid and invalid user principals + @since 3.20 + @jira_ticket PYTHON-574 + @test_category dse auth + @expected_result Client principals should be used by the underlying mechanism + + """ + + # Connect with valid principal + self.refresh_kerberos_tickets(self.cassandra_keytab, "cassandra@DATASTAX.COM", self.krb_conf) + auth_provider = DSEGSSAPIAuthProvider(service='dse', qops=["auth"], principal="cassandra@DATASTAX.COM") + self.connect_and_query(auth_provider) + connections = [c for holders in self.cluster.get_connection_holders() for c in holders.get_connections()] + + # Check to make sure our server_authenticator class is being set appropriate + for connection in connections: + self.assertTrue('DseAuthenticator' in connection.authenticator.server_authenticator_class) + + # Use invalid principal + auth_provider = DSEGSSAPIAuthProvider(service='dse', qops=["auth"], principal="notauser@DATASTAX.COM") + self.assertRaises(NoHostAvailable, self.connect_and_query, auth_provider) + + @greaterthanorequaldse51 + def test_proxy_login_with_kerberos(self): + """ + Test that the proxy login works with kerberos. + """ + # Set up users for proxy login test + self._setup_for_proxy() + + query = "select * from testkrbproxy.testproxy" + + # Try normal login with Charlie + self.refresh_kerberos_tickets(self.charlie_keytab, "charlie@DATASTAX.COM", self.krb_conf) + auth_provider = DSEGSSAPIAuthProvider(service='dse', qops=["auth"], principal="charlie@DATASTAX.COM") + self.connect_and_query(auth_provider, query=query) + + # Try proxy login with bob + self.refresh_kerberos_tickets(self.bob_keytab, "bob@DATASTAX.COM", self.krb_conf) + auth_provider = DSEGSSAPIAuthProvider(service='dse', qops=["auth"], principal="bob@DATASTAX.COM", + authorization_id='charlie@DATASTAX.COM') + self.connect_and_query(auth_provider, query=query) + + # Try logging with bob without mentioning charlie + self.refresh_kerberos_tickets(self.bob_keytab, "bob@DATASTAX.COM", self.krb_conf) + auth_provider = DSEGSSAPIAuthProvider(service='dse', qops=["auth"], principal="bob@DATASTAX.COM") + self.assertRaises(Unauthorized, self.connect_and_query, auth_provider, query=query) + + self._remove_proxy_setup() + + @greaterthanorequaldse51 + def test_proxy_login_with_kerberos_forbidden(self): + """ + Test that the proxy login fail when proxy role is not granted + """ + # Set up users for proxy login test + self._setup_for_proxy(False) + query = "select * from testkrbproxy.testproxy" + + # Try normal login with Charlie + self.refresh_kerberos_tickets(self.bob_keytab, "bob@DATASTAX.COM", self.krb_conf) + auth_provider = DSEGSSAPIAuthProvider(service='dse', qops=["auth"], principal="bob@DATASTAX.COM", + authorization_id='charlie@DATASTAX.COM') + self.assertRaises(NoHostAvailable, self.connect_and_query, auth_provider, query=query) + + self.refresh_kerberos_tickets(self.bob_keytab, "bob@DATASTAX.COM", self.krb_conf) + auth_provider = DSEGSSAPIAuthProvider(service='dse', qops=["auth"], principal="bob@DATASTAX.COM") + self.assertRaises(Unauthorized, self.connect_and_query, auth_provider, query=query) + + self._remove_proxy_setup() + + def 
_remove_proxy_setup(self):
+        os.environ['KRB5_CONFIG'] = self.krb_conf
+        self.refresh_kerberos_tickets(self.cassandra_keytab, "cassandra@DATASTAX.COM", self.krb_conf)
+        auth_provider = DSEGSSAPIAuthProvider(service='dse', qops=["auth"], principal='cassandra@DATASTAX.COM')
+        cluster = TestCluster(auth_provider=auth_provider)
+        session = cluster.connect()
+
+        session.execute("REVOKE PROXY.LOGIN ON ROLE '{0}' FROM '{1}'".format('charlie@DATASTAX.COM', 'bob@DATASTAX.COM'))
+
+        session.execute("DROP ROLE IF EXISTS '{0}';".format('bob@DATASTAX.COM'))
+        session.execute("DROP ROLE IF EXISTS '{0}';".format('charlie@DATASTAX.COM'))
+
+        # Drop the keyspace that was created for the proxy tests.
+        session.execute("DROP KEYSPACE testkrbproxy")
+
+        cluster.shutdown()
+
+    def _setup_for_proxy(self, grant=True):
+        os.environ['KRB5_CONFIG'] = self.krb_conf
+        self.refresh_kerberos_tickets(self.cassandra_keytab, "cassandra@DATASTAX.COM", self.krb_conf)
+        auth_provider = DSEGSSAPIAuthProvider(service='dse', qops=["auth"], principal='cassandra@DATASTAX.COM')
+        cluster = TestCluster(auth_provider=auth_provider)
+        session = cluster.connect()
+
+        stmts = [
+            "CREATE ROLE IF NOT EXISTS '{0}' WITH LOGIN = TRUE;".format('bob@DATASTAX.COM'),
+            "GRANT EXECUTE ON ALL AUTHENTICATION SCHEMES to 'bob@DATASTAX.COM'",
+            "CREATE ROLE IF NOT EXISTS '{0}' WITH LOGIN = TRUE;".format('charlie@DATASTAX.COM'),
+            "GRANT EXECUTE ON ALL AUTHENTICATION SCHEMES to 'charlie@DATASTAX.COM'",
+            # Create a keyspace and allow only charlie to query it.
+            "CREATE KEYSPACE testkrbproxy WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}",
+            "CREATE TABLE testkrbproxy.testproxy (id int PRIMARY KEY, value text)",
+            "GRANT ALL PERMISSIONS ON KEYSPACE testkrbproxy to '{0}'".format('charlie@DATASTAX.COM'),
+        ]
+
+        if grant:
+            stmts.append("GRANT PROXY.LOGIN ON ROLE '{0}' to '{1}'".format('charlie@DATASTAX.COM', 'bob@DATASTAX.COM'))
+
+        wait_role_manager_setup_then_execute(session, stmts)
+
+        cluster.shutdown()
+
+
+def clear_kerberos_tickets():
+    subprocess.call(['kdestroy'], shell=False)
+
+
+@attr('long')
+@requiredse
+class BaseDseProxyAuthTest(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(self):
+        """
+        This will setup the necessary infrastructure to run unified authentication tests.
+ """ + if not DSE_VERSION or DSE_VERSION < Version('5.1'): + return + self.cluster = None + + ccm_cluster = get_cluster() + # Stop cluster if running and configure it with the correct options + ccm_cluster.stop() + if isinstance(ccm_cluster, DseCluster): + # Setup dse options in dse.yaml + config_options = {'authentication_options': {'enabled': 'true', + 'default_scheme': 'internal', + 'scheme_permissions': 'true', + 'transitional_mode': 'normal'}, + 'authorization_options': {'enabled': 'true'} + } + + # Setup dse authenticator in cassandra.yaml + ccm_cluster.set_configuration_options({ + 'authenticator': 'com.datastax.bdp.cassandra.auth.DseAuthenticator', + 'authorizer': 'com.datastax.bdp.cassandra.auth.DseAuthorizer' + }) + ccm_cluster.set_dse_configuration_options(config_options) + ccm_cluster.start(wait_for_binary_proto=True, wait_other_notice=True) + else: + log.error("Cluster is not dse cluster test will fail") + + # Create users and test keyspace + self.user_role = 'user1' + self.server_role = 'server' + self.root_cluster = TestCluster(auth_provider=DSEPlainTextAuthProvider('cassandra', 'cassandra')) + self.root_session = self.root_cluster.connect() + + stmts = [ + "CREATE USER {0} WITH PASSWORD '{1}'".format(self.server_role, self.server_role), + "CREATE USER {0} WITH PASSWORD '{1}'".format(self.user_role, self.user_role), + "CREATE KEYSPACE testproxy WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}", + "CREATE TABLE testproxy.testproxy (id int PRIMARY KEY, value text)", + "GRANT ALL PERMISSIONS ON KEYSPACE testproxy to {0}".format(self.user_role) + ] + + wait_role_manager_setup_then_execute(self.root_session, stmts) + + @classmethod + def tearDownClass(self): + """ + Shutdown the root session. + """ + if not DSE_VERSION or DSE_VERSION < Version('5.1'): + return + self.root_session.execute('DROP KEYSPACE testproxy;') + self.root_session.execute('DROP USER {0}'.format(self.user_role)) + self.root_session.execute('DROP USER {0}'.format(self.server_role)) + self.root_cluster.shutdown() + + def tearDown(self): + """ + Shutdown the cluster and reset proxy permissions + """ + self.cluster.shutdown() + + self.root_session.execute("REVOKE PROXY.LOGIN ON ROLE {0} from {1}".format(self.user_role, self.server_role)) + self.root_session.execute("REVOKE PROXY.EXECUTE ON ROLE {0} from {1}".format(self.user_role, self.server_role)) + + def grant_proxy_login(self): + """ + Grant PROXY.LOGIN permission on a role to a specific user. + """ + self.root_session.execute("GRANT PROXY.LOGIN on role {0} to {1}".format(self.user_role, self.server_role)) + + def grant_proxy_execute(self): + """ + Grant PROXY.EXECUTE permission on a role to a specific user. + """ + self.root_session.execute("GRANT PROXY.EXECUTE on role {0} to {1}".format(self.user_role, self.server_role)) + + +@attr('long') +@greaterthanorequaldse51 +class DseProxyAuthTest(BaseDseProxyAuthTest): + """ + Tests Unified Auth. Proxy Login using SASL and Proxy Execute. 
+ """ + + @classmethod + def get_sasl_options(self, mechanism='PLAIN'): + sasl_options = { + "service": 'dse', + "username": 'server', + "mechanism": mechanism, + 'password': self.server_role, + 'authorization_id': self.user_role + } + return sasl_options + + def connect_and_query(self, auth_provider, execute_as=None, query="SELECT * FROM testproxy.testproxy"): + self.cluster = TestCluster(auth_provider=auth_provider) + self.session = self.cluster.connect() + rs = self.session.execute(query, execute_as=execute_as) + return rs + + def test_proxy_login_forbidden(self): + """ + Test that a proxy login is forbidden by default for a user. + @since 3.20 + @jira_ticket PYTHON-662 + @test_category dse auth + @expected_result connect and query should not be allowed + """ + auth_provider = SaslAuthProvider(**self.get_sasl_options()) + with self.assertRaises(Unauthorized): + self.connect_and_query(auth_provider) + + def test_proxy_login_allowed(self): + """ + Test that a proxy login is allowed with proper permissions. + @since 3.20 + @jira_ticket PYTHON-662 + @test_category dse auth + @expected_result connect and query should be allowed + """ + auth_provider = SaslAuthProvider(**self.get_sasl_options()) + self.grant_proxy_login() + self.connect_and_query(auth_provider) + + def test_proxy_execute_forbidden(self): + """ + Test that a proxy execute is forbidden by default for a user. + @since 3.20 + @jira_ticket PYTHON-662 + @test_category dse auth + @expected_result connect and query should not be allowed + """ + auth_provider = DSEPlainTextAuthProvider(self.server_role, self.server_role) + with self.assertRaises(Unauthorized): + self.connect_and_query(auth_provider, execute_as=self.user_role) + + def test_proxy_execute_allowed(self): + """ + Test that a proxy execute is allowed with proper permissions. + @since 3.20 + @jira_ticket PYTHON-662 + @test_category dse auth + @expected_result connect and query should be allowed + """ + auth_provider = DSEPlainTextAuthProvider(self.server_role, self.server_role) + self.grant_proxy_execute() + self.connect_and_query(auth_provider, execute_as=self.user_role) + + def test_connection_with_transitional_mode(self): + """ + Test that the driver can connect using TransitionalModePlainTextAuthProvider + @since 3.20 + @jira_ticket PYTHON-831 + @test_category dse auth + @expected_result connect and query should be allowed + """ + auth_provider = TransitionalModePlainTextAuthProvider() + self.assertIsNotNone(self.connect_and_query(auth_provider, query="SELECT * from system.local")) diff --git a/tests/integration/advanced/test_cont_paging.py b/tests/integration/advanced/test_cont_paging.py new file mode 100644 index 0000000000..0f64835674 --- /dev/null +++ b/tests/integration/advanced/test_cont_paging.py @@ -0,0 +1,245 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from tests.integration import use_singledc, greaterthanorequaldse51, BasicSharedKeyspaceUnitTestCaseRF3WM, \ + DSE_VERSION, ProtocolVersion, greaterthanorequaldse60, requiredse, TestCluster + +import logging +log = logging.getLogger(__name__) + +import unittest + +from itertools import cycle, count +from packaging.version import Version +import time + +from cassandra.cluster import ExecutionProfile, ContinuousPagingOptions +from cassandra.concurrent import execute_concurrent +from cassandra.query import SimpleStatement + + +def setup_module(): + if DSE_VERSION: + use_singledc() + + +@requiredse +class BaseContPagingTests(): + @classmethod + def setUpClass(cls): + if not DSE_VERSION or DSE_VERSION < cls.required_dse_version: + return + + cls.execution_profiles = {"CONTDEFAULT": ExecutionProfile(continuous_paging_options=ContinuousPagingOptions()), + "ONEPAGE": ExecutionProfile( + continuous_paging_options=ContinuousPagingOptions(max_pages=1)), + "MANYPAGES": ExecutionProfile( + continuous_paging_options=ContinuousPagingOptions(max_pages=10)), + "BYTES": ExecutionProfile(continuous_paging_options=ContinuousPagingOptions( + page_unit=ContinuousPagingOptions.PagingUnit.BYTES)), + "SLOW": ExecutionProfile( + continuous_paging_options=ContinuousPagingOptions(max_pages_per_second=1)), } + cls.sane_eps = ["CONTDEFAULT", "BYTES"] + + @classmethod + def tearDownClass(cls): + if not DSE_VERSION or DSE_VERSION < cls.required_dse_version: + return + + @classmethod + def create_cluster(cls): + + cls.cluster_with_profiles = TestCluster(protocol_version=cls.protocol_version, execution_profiles=cls.execution_profiles) + + cls.session_with_profiles = cls.cluster_with_profiles.connect(wait_for_all_pools=True) + statements_and_params = zip( + cycle(["INSERT INTO " + cls.ks_name + "." + cls.ks_name + " (k, v) VALUES (%s, 0)"]), + [(i,) for i in range(150)]) + execute_concurrent(cls.session_with_profiles, list(statements_and_params)) + + cls.select_all_statement = "SELECT * FROM {0}.{0}".format(cls.ks_name) + + def test_continuous_paging(self): + """ + Test to ensure that various continuous paging schemes return the full set of results. + @since 3.20 + @jira_ticket PYTHON-615 + @expected_result various continuous paging options should fetch all the results + + @test_category queries + """ + for ep in self.execution_profiles.keys(): + results = list(self.session_with_profiles.execute(self.select_all_statement, execution_profile= ep)) + self.assertEqual(len(results), 150) + + + def test_page_fetch_size(self): + """ + Test to ensure that continuous paging works appropriately with fetch size. 
+ @since 3.20 + @jira_ticket PYTHON-615 + @expected_result continuous paging options should work sensibly with various fetch size + + @test_category queries + """ + + # Since we fetch one page at a time results should match fetch size + for fetch_size in (2, 3, 7, 10, 99, 100, 101, 150): + self.session_with_profiles.default_fetch_size = fetch_size + results = list(self.session_with_profiles.execute(self.select_all_statement, execution_profile= "ONEPAGE")) + self.assertEqual(len(results), fetch_size) + + # Since we fetch ten pages at a time results should match fetch size * 10 + for fetch_size in (2, 3, 7, 10, 15): + self.session_with_profiles.default_fetch_size = fetch_size + results = list(self.session_with_profiles.execute(self.select_all_statement, execution_profile= "MANYPAGES")) + self.assertEqual(len(results), fetch_size*10) + + # Default settings for continuous paging should be able to fetch all results regardless of fetch size + # Changing the units should, not affect the number of results, if max_pages is not set + for profile in self.sane_eps: + for fetch_size in (2, 3, 7, 10, 15): + self.session_with_profiles.default_fetch_size = fetch_size + results = list(self.session_with_profiles.execute(self.select_all_statement, execution_profile= profile)) + self.assertEqual(len(results), 150) + + # This should take around 3 seconds to fetch but should still complete with all results + self.session_with_profiles.default_fetch_size = 50 + results = list(self.session_with_profiles.execute(self.select_all_statement, execution_profile= "SLOW")) + self.assertEqual(len(results), 150) + + def test_paging_cancel(self): + """ + Test to ensure we can cancel a continuous paging session once it's started + @since 3.20 + @jira_ticket PYTHON-615 + @expected_result This query should be canceled before any sizable amount of results can be returned + @test_category queries + """ + + self.session_with_profiles.default_fetch_size = 1 + # This combination should fetch one result a second. 
We should see a very few results + results = self.session_with_profiles.execute_async(self.select_all_statement, execution_profile= "SLOW") + result_set = results.result() + result_set.cancel_continuous_paging() + result_lst = list(result_set) + self.assertLess(len(result_lst), 2, "Cancel should have aborted fetch immediately") + + def test_con_paging_verify_writes(self): + """ + Test to validate results with a few continuous paging options + @since 3.20 + @jira_ticket PYTHON-615 + @expected_result all results should be returned correctly + @test_category queries + """ + prepared = self.session_with_profiles.prepare(self.select_all_statement) + + + for ep in self.sane_eps: + for fetch_size in (2, 3, 7, 10, 99, 100, 101, 10000): + self.session_with_profiles.default_fetch_size = fetch_size + results = self.session_with_profiles.execute(self.select_all_statement, execution_profile=ep) + result_array = set() + result_set = set() + for result in results: + result_array.add(result.k) + result_set.add(result.v) + + self.assertEqual(set(range(150)), result_array) + self.assertEqual(set([0]), result_set) + + statement = SimpleStatement(self.select_all_statement) + results = self.session_with_profiles.execute(statement, execution_profile=ep) + result_array = set() + result_set = set() + for result in results: + result_array.add(result.k) + result_set.add(result.v) + + self.assertEqual(set(range(150)), result_array) + self.assertEqual(set([0]), result_set) + + results = self.session_with_profiles.execute(prepared, execution_profile=ep) + result_array = set() + result_set = set() + for result in results: + result_array.add(result.k) + result_set.add(result.v) + + self.assertEqual(set(range(150)), result_array) + self.assertEqual(set([0]), result_set) + + def test_can_get_results_when_no_more_pages(self): + """ + Test to validate that the results can be fetched when + has_more_pages is False + @since 3.20 + @jira_ticket PYTHON-946 + @expected_result the results can be fetched + @test_category queries + """ + generator_expanded = [] + def get_all_rows(generator, future, generator_expanded): + self.assertFalse(future.has_more_pages) + + generator_expanded.extend(list(generator)) + print("Setting generator_expanded to True") + + future = self.session_with_profiles.execute_async("SELECT * from system.local LIMIT 10", + execution_profile="CONTDEFAULT") + future.add_callback(get_all_rows, future, generator_expanded) + time.sleep(5) + self.assertTrue(generator_expanded) + + +@requiredse +@greaterthanorequaldse51 +class ContPagingTestsDSEV1(BaseContPagingTests, BasicSharedKeyspaceUnitTestCaseRF3WM): + @classmethod + def setUpClass(cls): + cls.required_dse_version = BaseContPagingTests.required_dse_version = Version('5.1') + if not DSE_VERSION or DSE_VERSION < cls.required_dse_version: + return + + BasicSharedKeyspaceUnitTestCaseRF3WM.setUpClass() + BaseContPagingTests.setUpClass() + + cls.protocol_version = ProtocolVersion.DSE_V1 + cls.create_cluster() + + +@requiredse +@greaterthanorequaldse60 +class ContPagingTestsDSEV2(BaseContPagingTests, BasicSharedKeyspaceUnitTestCaseRF3WM): + @classmethod + def setUpClass(cls): + cls.required_dse_version = BaseContPagingTests.required_dse_version = Version('6.0') + if not DSE_VERSION or DSE_VERSION < cls.required_dse_version: + return + + BasicSharedKeyspaceUnitTestCaseRF3WM.setUpClass() + BaseContPagingTests.setUpClass() + + more_profiles = { + "SMALL_QUEUE": ExecutionProfile(continuous_paging_options=ContinuousPagingOptions(max_queue_size=2)), + "BIG_QUEUE": 
ExecutionProfile(continuous_paging_options=ContinuousPagingOptions(max_queue_size=400)) + } + cls.sane_eps += ["SMALL_QUEUE", "BIG_QUEUE"] + cls.execution_profiles.update(more_profiles) + + cls.protocol_version = ProtocolVersion.DSE_V2 + cls.create_cluster() diff --git a/tests/integration/advanced/test_cqlengine_where_operators.py b/tests/integration/advanced/test_cqlengine_where_operators.py new file mode 100644 index 0000000000..b39cde0f02 --- /dev/null +++ b/tests/integration/advanced/test_cqlengine_where_operators.py @@ -0,0 +1,112 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import os +import time + +from cassandra.cqlengine import columns, connection, models +from cassandra.cqlengine.management import (CQLENG_ALLOW_SCHEMA_MANAGEMENT, + create_keyspace_simple, drop_table, + sync_table) +from cassandra.cqlengine.statements import IsNotNull +from tests.integration import DSE_VERSION, requiredse, CASSANDRA_IP, greaterthanorequaldse60, TestCluster +from tests.integration.advanced import use_single_node_with_graph_and_solr +from tests.integration.cqlengine import DEFAULT_KEYSPACE + + +class SimpleNullableModel(models.Model): + __keyspace__ = DEFAULT_KEYSPACE + partition = columns.Integer(primary_key=True) + nullable = columns.Integer(required=False) + # nullable = columns.Integer(required=False, custom_index=True) + + +def setup_module(): + if DSE_VERSION: + os.environ[CQLENG_ALLOW_SCHEMA_MANAGEMENT] = '1' + use_single_node_with_graph_and_solr() + setup_connection(DEFAULT_KEYSPACE) + create_keyspace_simple(DEFAULT_KEYSPACE, 1) + sync_table(SimpleNullableModel) + + +def setup_connection(keyspace_name): + connection.setup([CASSANDRA_IP], + # consistency=ConsistencyLevel.ONE, + # protocol_version=PROTOCOL_VERSION, + default_keyspace=keyspace_name) + + +def teardown_module(): + if DSE_VERSION: + drop_table(SimpleNullableModel) + + +@requiredse +class IsNotNullTests(unittest.TestCase): + + @classmethod + def setUpClass(cls): + if DSE_VERSION: + cls.cluster = TestCluster() + + @greaterthanorequaldse60 + def test_is_not_null_execution(self): + """ + Verify that CQL statements have correct syntax when executed + If we wanted them to return something meaningful and not a InvalidRequest + we'd have to create an index in search for the column we are using + IsNotNull + + @since 3.20 + @jira_ticket PYTHON-968 + @expected_result InvalidRequest is arisen + + @test_category cqlengine + """ + cluster = TestCluster() + self.addCleanup(cluster.shutdown) + session = cluster.connect() + + SimpleNullableModel.create(partition=1, nullable=2) + SimpleNullableModel.create(partition=2, nullable=None) + + self.addCleanup(session.execute, "DROP SEARCH INDEX ON {}".format( + SimpleNullableModel.column_family_name())) + create_index_stmt = ( + "CREATE SEARCH 
INDEX ON {} WITH COLUMNS nullable " + "".format(SimpleNullableModel.column_family_name())) + session.execute(create_index_stmt) + + SimpleNullableModel.create(partition=1, nullable=1) + SimpleNullableModel.create(partition=2, nullable=None) + + # TODO: block on indexing more precisely + time.sleep(5) + + self.assertEqual(len(list(SimpleNullableModel.objects.all())), 2) + self.assertEqual( + len(list( + SimpleNullableModel.filter(IsNotNull("nullable"), partition__eq=2) + )), + 0) + self.assertEqual( + len(list( + SimpleNullableModel.filter(IsNotNull("nullable"), partition__eq=1) + )), + 1) diff --git a/tests/integration/advanced/test_geometry.py b/tests/integration/advanced/test_geometry.py new file mode 100644 index 0000000000..3bbf04bb7a --- /dev/null +++ b/tests/integration/advanced/test_geometry.py @@ -0,0 +1,251 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from tests.integration import DSE_VERSION, requiredse +from tests.integration.advanced import BasicGeometricUnitTestCase, use_single_node_with_graph +from cassandra.util import OrderedMap, sortedset +from collections import namedtuple + +import unittest +from uuid import uuid1 +from cassandra.util import Point, LineString, Polygon +from cassandra.cqltypes import LineStringType, PointType, PolygonType + + +def setup_module(): + if DSE_VERSION: + use_single_node_with_graph() + + +class AbstractGeometricTypeTest(): + + original_value = "" + + def test_should_insert_simple(self): + """ + This tests will attempt to insert a point, polygon, or line, using simple inline formatting. + @since 3.20 + @jira_ticket PYTHON-456 + @test_category dse geometric + @expected_result geometric types should be able to be inserted and queried. + """ + uuid_key = uuid1() + self.session.execute("INSERT INTO tbl (k, g) VALUES (%s, %s)", [uuid_key, self.original_value]) + self.validate('g', uuid_key, self.original_value) + + def test_should_insert_simple_prepared(self): + """ + This tests will attempt to insert a point, polygon, or line, using prepared statements. + @since 3.20 + @jira_ticket PYTHON-456 + @test_category dse geometric + @expected_result geometric types should be able to be inserted and queried. + """ + uuid_key = uuid1() + prepared = self.session.prepare("INSERT INTO tbl (k, g) VALUES (?, ?)") + self.session.execute(prepared, (uuid_key, self.original_value)) + self.validate('g', uuid_key, self.original_value) + + def test_should_insert_simple_prepared_with_bound(self): + """ + This tests will attempt to insert a point, polygon, or line, using prepared statements and bind. + @since 3.20 + @jira_ticket PYTHON-456 + @test_category dse geometric + @expected_result geometric types should be able to be inserted and queried. 
+ """ + uuid_key = uuid1() + prepared = self.session.prepare("INSERT INTO tbl (k, g) VALUES (?, ?)") + bound_statement = prepared.bind((uuid_key, self.original_value)) + self.session.execute(bound_statement) + self.validate('g', uuid_key, self.original_value) + + def test_should_insert_as_list(self): + """ + This tests will attempt to insert a point, polygon, or line, as values of list. + @since 3.20 + @jira_ticket PYTHON-456 + @test_category dse geometric + @expected_result geometric types should be able to be inserted and queried as a list. + """ + uuid_key = uuid1() + prepared = self.session.prepare("INSERT INTO tbl (k, l) VALUES (?, ?)") + bound_statement = prepared.bind((uuid_key, [self.original_value])) + self.session.execute(bound_statement) + self.validate('l', uuid_key, [self.original_value]) + + def test_should_insert_as_set(self): + """ + This tests will attempt to insert a point, polygon, or line, as values of set. + @since 3.20 + @jira_ticket PYTHON-456 + @test_category dse geometric + @expected_result geometric types should be able to be inserted and queried as a set. + """ + uuid_key = uuid1() + prepared = self.session.prepare("INSERT INTO tbl (k, s) VALUES (?, ?)") + bound_statement = prepared.bind((uuid_key, sortedset([self.original_value]))) + self.session.execute(bound_statement) + self.validate('s', uuid_key, sortedset([self.original_value])) + + def test_should_insert_as_map_keys(self): + """ + This tests will attempt to insert a point, polygon, or line, as keys of a map. + @since 3.20 + @jira_ticket PYTHON-456 + @test_category dse geometric + @expected_result geometric types should be able to be inserted and queried as keys of a map. + """ + uuid_key = uuid1() + prepared = self.session.prepare("INSERT INTO tbl (k, m0) VALUES (?, ?)") + bound_statement = prepared.bind((uuid_key, OrderedMap(zip([self.original_value], [1])))) + self.session.execute(bound_statement) + self.validate('m0', uuid_key, OrderedMap(zip([self.original_value], [1]))) + + def test_should_insert_as_map_values(self): + """ + This tests will attempt to insert a point, polygon, or line, as values of a map. + @since 3.20 + @jira_ticket PYTHON-456 + @test_category dse geometric + @expected_result geometric types should be able to be inserted and queried as values of a map. + """ + uuid_key = uuid1() + prepared = self.session.prepare("INSERT INTO tbl (k, m1) VALUES (?, ?)") + bound_statement = prepared.bind((uuid_key, OrderedMap(zip([1], [self.original_value])))) + self.session.execute(bound_statement) + self.validate('m1', uuid_key, OrderedMap(zip([1], [self.original_value]))) + + def test_should_insert_as_tuple(self): + """ + This tests will attempt to insert a point, polygon, or line, as values of a tuple. + @since 3.20 + @jira_ticket PYTHON-456 + @test_category dse geometric + @expected_result geometric types should be able to be inserted and queried as values of a tuple. + """ + uuid_key = uuid1() + prepared = self.session.prepare("INSERT INTO tbl (k, t) VALUES (?, ?)") + bound_statement = prepared.bind((uuid_key, (self.original_value, self.original_value, self.original_value))) + self.session.execute(bound_statement) + self.validate('t', uuid_key, (self.original_value, self.original_value, self.original_value)) + + def test_should_insert_as_udt(self): + """ + This tests will attempt to insert a point, polygon, or line, as members of a udt. 
+ @since 3.20 + @jira_ticket PYTHON-456 + @test_category dse geometric + @expected_result geometric types should be able to be inserted and queried as members of a udt. + """ + UDT1 = namedtuple('udt1', ('g')) + self.cluster.register_user_type(self.ks_name, 'udt1', UDT1) + uuid_key = uuid1() + prepared = self.session.prepare("INSERT INTO tbl (k, u) values (?, ?)") + bound_statement = prepared.bind((uuid_key, UDT1(self.original_value))) + self.session.execute(bound_statement) + rs = self.session.execute("SELECT {0} from {1} where k={2}".format('u', 'tbl', uuid_key)) + retrieved_udt = rs[0]._asdict()['u'] + + self.assertEqual(retrieved_udt.g, self.original_value) + + def test_should_accept_as_partition_key(self): + """ + This tests will attempt to insert a point, polygon, or line, as a partition key. + @since 3.20 + @jira_ticket PYTHON-456 + @test_category dse geometric + @expected_result geometric types should be able to be inserted and queried as a partition key. + """ + prepared = self.session.prepare("INSERT INTO tblpk (k, v) VALUES (?, ?)") + bound_statement = prepared.bind((self.original_value, 1)) + self.session.execute(bound_statement) + rs = self.session.execute("SELECT k, v FROM tblpk") + foundpk = rs[0]._asdict()['k'] + self.assertEqual(foundpk, self.original_value) + + def validate(self, value, key, expected): + """ + Simple utility method used for validation of inserted types. + """ + rs = self.session.execute("SELECT {0} from tbl where k={1}".format(value, key)) + retrieved = rs[0]._asdict()[value] + self.assertEqual(expected, retrieved) + + def test_insert_empty_with_string(self): + """ + This tests will attempt to insert a point, polygon, or line, as Empty + @since 3.20 + @jira_ticket PYTHON-481 + @test_category dse geometric + @expected_result EMPTY as a keyword should be honored + """ + uuid_key = uuid1() + self.session.execute("INSERT INTO tbl (k, g) VALUES (%s, %s)", [uuid_key, self.empty_statement]) + self.validate('g', uuid_key, self.empty_value) + + def test_insert_empty_with_object(self): + """ + This tests will attempt to insert a point, polygon, or line, as Empty + @since 3.20 + @jira_ticket PYTHON-481 + @test_category dse geometric + @expected_result EMPTY as a keyword should be used with empty objects + """ + uuid_key = uuid1() + prepared = self.session.prepare("INSERT INTO tbl (k, g) VALUES (?, ?)") + self.session.execute(prepared, (uuid_key, self.empty_value)) + self.validate('g', uuid_key, self.empty_value) + + +@requiredse +class BasicGeometricPointTypeTest(AbstractGeometricTypeTest, BasicGeometricUnitTestCase): + """ + Runs all the geometric tests against PointType + """ + cql_type_name = "'{0}'".format(PointType.typename) + original_value = Point(.5, .13) + + @unittest.skip("Empty String") + def test_insert_empty_with_string(self): + pass + + @unittest.skip("Empty String") + def test_insert_empty_with_object(self): + pass + + +@requiredse +class BasicGeometricLineStringTypeTest(AbstractGeometricTypeTest, BasicGeometricUnitTestCase): + """ + Runs all the geometric tests against LineStringType + """ + cql_type_name = cql_type_name = "'{0}'".format(LineStringType.typename) + original_value = LineString(((1, 2), (3, 4), (9871234, 1235487215))) + empty_statement = 'LINESTRING EMPTY' + empty_value = LineString() + + +@requiredse +class BasicGeometricPolygonTypeTest(AbstractGeometricTypeTest, BasicGeometricUnitTestCase): + """ + Runs all the geometric tests against PolygonType + """ + cql_type_name = cql_type_name = "'{0}'".format(PolygonType.typename) + 
original_value = Polygon([(10.0, 10.0), (110.0, 10.0), (110., 110.0), (10., 110.0), (10., 10.0)], [[(20., 20.0), (20., 30.0), (30., 30.0), (30., 20.0), (20., 20.0)], [(40., 20.0), (40., 30.0), (50., 30.0), (50., 20.0), (40., 20.0)]]) + empty_statement = 'POLYGON EMPTY' + empty_value = Polygon() diff --git a/tests/integration/advanced/test_spark.py b/tests/integration/advanced/test_spark.py new file mode 100644 index 0000000000..197f99c934 --- /dev/null +++ b/tests/integration/advanced/test_spark.py @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from cassandra.cluster import EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT +from cassandra.graph import SimpleGraphStatement +from tests.integration import DSE_VERSION, requiredse +from tests.integration.advanced import use_singledc_wth_graph_and_spark, find_spark_master +from tests.integration.advanced.graph import BasicGraphUnitTestCase, ClassicGraphFixtures +log = logging.getLogger(__name__) + + +def setup_module(): + if DSE_VERSION: + use_singledc_wth_graph_and_spark() + + +@requiredse +class SparkLBTests(BasicGraphUnitTestCase): + """ + Test to validate that analytics query can run in a multi-node environment. Also check to ensure + that the master spark node is correctly targeted when OLAP queries are run + + @since 3.20 + @jira_ticket PYTHON-510 + @expected_result OLAP results should come back correctly, master spark coordinator should always be picked. + @test_category dse graph + """ + def test_spark_analytic_query(self): + self.session.execute_graph(ClassicGraphFixtures.classic()) + spark_master = find_spark_master(self.session) + + # Run multiple times to ensure we don't round-robin + for i in range(3): + to_run = SimpleGraphStatement("g.V().count()") + rs = self.session.execute_graph(to_run, execution_profile=EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT) + self.assertEqual(rs[0].value, 7) + self.assertEqual(rs.response_future._current_host.address, spark_master) diff --git a/tests/integration/advanced/test_unixsocketendpoint.py b/tests/integration/advanced/test_unixsocketendpoint.py new file mode 100644 index 0000000000..f2795d1a68 --- /dev/null +++ b/tests/integration/advanced/test_unixsocketendpoint.py @@ -0,0 +1,74 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License +import unittest + +import time +import subprocess +import logging + +from cassandra.cluster import ExecutionProfile, EXEC_PROFILE_DEFAULT +from cassandra.connection import UnixSocketEndPoint +from cassandra.policies import WhiteListRoundRobinPolicy, RoundRobinPolicy + +from tests import notwindows +from tests.integration import use_single_node, TestCluster + +log = logging.getLogger() +log.setLevel('DEBUG') + +UNIX_SOCKET_PATH = '/tmp/cass.sock' + + +def setup_module(): + use_single_node() + + +class UnixSocketWhiteListRoundRobinPolicy(WhiteListRoundRobinPolicy): + def __init__(self, hosts): + self._allowed_hosts = self._allowed_hosts_resolved = tuple(hosts) + RoundRobinPolicy.__init__(self) + + +@notwindows +class UnixSocketTest(unittest.TestCase): + + @classmethod + def setUpClass(cls): + log.debug("Starting socat...") + cls.proc = subprocess.Popen( + ['socat', + 'UNIX-LISTEN:%s,fork' % UNIX_SOCKET_PATH, + 'TCP:localhost:9042'], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT) + + time.sleep(1) + if cls.proc.poll() is not None: + for line in cls.proc.stdout.readlines(): + log.debug("socat: " + line.decode('utf-8')) + raise Exception("Error while starting socat. Return code: %d" % cls.proc.returncode) + + lbp = UnixSocketWhiteListRoundRobinPolicy([UNIX_SOCKET_PATH]) + ep = ExecutionProfile(load_balancing_policy=lbp) + endpoint = UnixSocketEndPoint(UNIX_SOCKET_PATH) + cls.cluster = TestCluster(contact_points=[endpoint], execution_profiles={EXEC_PROFILE_DEFAULT: ep}) + + @classmethod + def tearDownClass(cls): + cls.cluster.shutdown() + cls.proc.terminate() + + def test_unix_socket_connection(self): + s = self.cluster.connect() + s.execute('select * from system.local') diff --git a/tests/integration/cloud/__init__.py b/tests/integration/cloud/__init__.py new file mode 100644 index 0000000000..a6a4ab7a5d --- /dev/null +++ b/tests/integration/cloud/__init__.py @@ -0,0 +1,113 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License +from cassandra.cluster import Cluster + +import unittest + +import os +import subprocess + +from tests.integration import CLOUD_PROXY_PATH, USE_CASS_EXTERNAL + + +def setup_package(): + if CLOUD_PROXY_PATH and not USE_CASS_EXTERNAL: + start_cloud_proxy() + + +def teardown_package(): + if not USE_CASS_EXTERNAL: + stop_cloud_proxy() + + +class CloudProxyCluster(unittest.TestCase): + + creds_dir = os.path.join(os.path.abspath(CLOUD_PROXY_PATH or ''), 'certs/bundles/') + creds = os.path.join(creds_dir, 'creds-v1.zip') + creds_no_auth = os.path.join(creds_dir, 'creds-v1-wo-creds.zip') + creds_unreachable = os.path.join(creds_dir, 'creds-v1-unreachable.zip') + creds_invalid_ca = os.path.join(creds_dir, 'creds-v1-invalid-ca.zip') + + cluster, connect = None, False + session = None + + @classmethod + def connect(cls, creds, **kwargs): + cloud_config = { + 'secure_connect_bundle': creds, + } + cls.cluster = Cluster(cloud=cloud_config, protocol_version=4, **kwargs) + cls.session = cls.cluster.connect(wait_for_all_pools=True) + + def tearDown(self): + if self.cluster: + self.cluster.shutdown() + + +class CloudProxyServer(object): + """ + Class for starting and stopping the proxy (sni_single_endpoint) + """ + + ccm_command = 'docker exec $(docker ps -a -q --filter ancestor=single_endpoint) ccm {}' + + def __init__(self, CLOUD_PROXY_PATH): + self.CLOUD_PROXY_PATH = CLOUD_PROXY_PATH + self.running = False + + def start(self): + return_code = subprocess.call( + ['REQUIRE_CLIENT_CERTIFICATE=true ./run.sh'], + cwd=self.CLOUD_PROXY_PATH, + shell=True) + if return_code != 0: + raise Exception("Error while starting proxy server") + self.running = True + + def stop(self): + if self.is_running(): + subprocess.call( + ["docker kill $(docker ps -a -q --filter ancestor=single_endpoint)"], + shell=True) + self.running = False + + def is_running(self): + return self.running + + def start_node(self, id): + subcommand = 'node{} start --jvm_arg "-Ddse.product_type=DATASTAX_APOLLO" --root --wait-for-binary-proto'.format(id) + subprocess.call( + [self.ccm_command.format(subcommand)], + shell=True) + + def stop_node(self, id): + subcommand = 'node{} stop'.format(id) + subprocess.call( + [self.ccm_command.format(subcommand)], + shell=True) + + +CLOUD_PROXY_SERVER = CloudProxyServer(CLOUD_PROXY_PATH) + + +def start_cloud_proxy(): + """ + Starts and waits for the proxy to run + """ + CLOUD_PROXY_SERVER.stop() + CLOUD_PROXY_SERVER.start() + + +def stop_cloud_proxy(): + CLOUD_PROXY_SERVER.stop() diff --git a/tests/integration/cloud/conftest.py b/tests/integration/cloud/conftest.py new file mode 100644 index 0000000000..6bfda32534 --- /dev/null +++ b/tests/integration/cloud/conftest.py @@ -0,0 +1,10 @@ +import pytest + +from tests.integration.cloud import setup_package, teardown_package + + +@pytest.fixture(scope='session', autouse=True) +def setup_and_teardown_packages(): + setup_package() + yield + teardown_package() diff --git a/tests/integration/cloud/test_cloud.py b/tests/integration/cloud/test_cloud.py new file mode 100644 index 0000000000..1c1e75c1ee --- /dev/null +++ b/tests/integration/cloud/test_cloud.py @@ -0,0 +1,244 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License +from cassandra.datastax.cloud import parse_metadata_info +from cassandra.query import SimpleStatement +from cassandra.cqlengine import connection +from cassandra.cqlengine.management import sync_table, create_keyspace_simple +from cassandra.cqlengine.models import Model +from cassandra.cqlengine import columns + +import unittest + +from ssl import SSLContext, PROTOCOL_TLS + +from cassandra import DriverException, ConsistencyLevel, InvalidRequest +from cassandra.cluster import NoHostAvailable, ExecutionProfile, Cluster, _execution_profile_to_string +from cassandra.connection import SniEndPoint +from cassandra.auth import PlainTextAuthProvider +from cassandra.policies import TokenAwarePolicy, DCAwareRoundRobinPolicy, ConstantReconnectionPolicy + +from unittest.mock import patch + +from tests.integration import requirescloudproxy +from tests.util import wait_until_not_raised +from tests.integration.cloud import CloudProxyCluster, CLOUD_PROXY_SERVER + +DISALLOWED_CONSISTENCIES = [ + ConsistencyLevel.ANY, + ConsistencyLevel.ONE, + ConsistencyLevel.LOCAL_ONE +] + + +@requirescloudproxy +class CloudTests(CloudProxyCluster): + def hosts_up(self): + return [h for h in self.cluster.metadata.all_hosts() if h.is_up] + + def test_resolve_and_connect(self): + self.connect(self.creds) + + self.assertEqual(len(self.hosts_up()), 3) + for host in self.cluster.metadata.all_hosts(): + self.assertTrue(host.is_up) + self.assertIsInstance(host.endpoint, SniEndPoint) + self.assertEqual(str(host.endpoint), "{}:{}:{}".format( + host.endpoint.address, host.endpoint.port, host.host_id)) + self.assertIn(host.endpoint._resolved_address, ("127.0.0.1", '::1')) + + def test_match_system_local(self): + self.connect(self.creds) + + self.assertEqual(len(self.hosts_up()), 3) + for host in self.cluster.metadata.all_hosts(): + row = self.session.execute('SELECT * FROM system.local', host=host).one() + self.assertEqual(row.host_id, host.host_id) + self.assertEqual(row.rpc_address, host.broadcast_rpc_address) + + def test_set_auth_provider(self): + self.connect(self.creds) + self.assertIsInstance(self.cluster.auth_provider, PlainTextAuthProvider) + self.assertEqual(self.cluster.auth_provider.username, 'user1') + self.assertEqual(self.cluster.auth_provider.password, 'user1') + + def test_support_leaving_the_auth_unset(self): + with self.assertRaises(NoHostAvailable): + self.connect(self.creds_no_auth) + self.assertIsNone(self.cluster.auth_provider) + + def test_support_overriding_auth_provider(self): + try: + self.connect(self.creds, auth_provider=PlainTextAuthProvider('invalid', 'invalid')) + except: + pass # this will fail soon when sni_single_endpoint is updated + self.assertIsInstance(self.cluster.auth_provider, PlainTextAuthProvider) + self.assertEqual(self.cluster.auth_provider.username, 'invalid') + self.assertEqual(self.cluster.auth_provider.password, 'invalid') + + def test_error_overriding_ssl_context(self): + with self.assertRaises(ValueError) as cm: + self.connect(self.creds, ssl_context=SSLContext(PROTOCOL_TLS)) + + self.assertIn('cannot be specified with a cloud configuration', str(cm.exception)) + + def 
test_error_overriding_ssl_options(self): + with self.assertRaises(ValueError) as cm: + self.connect(self.creds, ssl_options={'check_hostname': True}) + + self.assertIn('cannot be specified with a cloud configuration', str(cm.exception)) + + def _bad_hostname_metadata(self, config, http_data): + config = parse_metadata_info(config, http_data) + config.sni_host = "127.0.0.1" + return config + + def test_verify_hostname(self): + with patch('cassandra.datastax.cloud.parse_metadata_info', wraps=self._bad_hostname_metadata): + with self.assertRaises(NoHostAvailable) as e: + self.connect(self.creds) + self.assertIn("hostname", str(e.exception).lower()) + + def test_error_when_bundle_doesnt_exist(self): + try: + self.connect('/invalid/path/file.zip') + except Exception as e: + self.assertIsInstance(e, FileNotFoundError) + + def test_load_balancing_policy_is_dcawaretokenlbp(self): + self.connect(self.creds) + self.assertIsInstance(self.cluster.profile_manager.default.load_balancing_policy, + TokenAwarePolicy) + self.assertIsInstance(self.cluster.profile_manager.default.load_balancing_policy._child_policy, + DCAwareRoundRobinPolicy) + + def test_resolve_and_reconnect_on_node_down(self): + + self.connect(self.creds, + idle_heartbeat_interval=1, idle_heartbeat_timeout=1, + reconnection_policy=ConstantReconnectionPolicy(120)) + + self.assertEqual(len(self.hosts_up()), 3) + CLOUD_PROXY_SERVER.stop_node(1) + wait_until_not_raised( + lambda: self.assertEqual(len(self.hosts_up()), 2), + 0.02, 250) + + host = [h for h in self.cluster.metadata.all_hosts() if not h.is_up][0] + with patch.object(SniEndPoint, "resolve", wraps=host.endpoint.resolve) as mocked_resolve: + CLOUD_PROXY_SERVER.start_node(1) + wait_until_not_raised( + lambda: self.assertEqual(len(self.hosts_up()), 3), + 0.02, 250) + mocked_resolve.assert_called() + + def test_metadata_unreachable(self): + with self.assertRaises(DriverException) as cm: + self.connect(self.creds_unreachable, connect_timeout=1) + + self.assertIn('Unable to connect to the metadata service', str(cm.exception)) + + def test_metadata_ssl_error(self): + with self.assertRaises(DriverException) as cm: + self.connect(self.creds_invalid_ca) + + self.assertIn('Unable to connect to the metadata', str(cm.exception)) + + def test_default_consistency(self): + self.connect(self.creds) + self.assertEqual(self.session.default_consistency_level, ConsistencyLevel.LOCAL_QUORUM) + # Verify EXEC_PROFILE_DEFAULT, EXEC_PROFILE_GRAPH_DEFAULT, + # EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT, EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT + for ep_key in self.cluster.profile_manager.profiles.keys(): + ep = self.cluster.profile_manager.profiles[ep_key] + self.assertEqual( + ep.consistency_level, + ConsistencyLevel.LOCAL_QUORUM, + "Expecting LOCAL QUORUM for profile {}, but got {} instead".format( + _execution_profile_to_string(ep_key), ConsistencyLevel.value_to_name[ep.consistency_level] + )) + + def test_default_consistency_of_execution_profiles(self): + cloud_config = {'secure_connect_bundle': self.creds} + self.cluster = Cluster(cloud=cloud_config, protocol_version=4, execution_profiles={ + 'pre_create_default_ep': ExecutionProfile(), + 'pre_create_changed_ep': ExecutionProfile( + consistency_level=ConsistencyLevel.LOCAL_ONE, + ), + }) + self.cluster.add_execution_profile('pre_connect_default_ep', ExecutionProfile()) + self.cluster.add_execution_profile( + 'pre_connect_changed_ep', + ExecutionProfile( + consistency_level=ConsistencyLevel.LOCAL_ONE, + ) + ) + session = self.cluster.connect(wait_for_all_pools=True) + 
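+        # Profiles that do not set an explicit consistency_level are expected to inherit the
+        # cloud default (LOCAL_QUORUM), whether added at creation, before connect(), or after connect().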
+        self.cluster.add_execution_profile('post_connect_default_ep', ExecutionProfile())
+        self.cluster.add_execution_profile(
+            'post_connect_changed_ep',
+            ExecutionProfile(
+                consistency_level=ConsistencyLevel.LOCAL_ONE,
+            )
+        )
+
+        for default in ['pre_create_default_ep', 'pre_connect_default_ep', 'post_connect_default_ep']:
+            cl = self.cluster.profile_manager.profiles[default].consistency_level
+            self.assertEqual(
+                cl, ConsistencyLevel.LOCAL_QUORUM,
+                "Expecting LOCAL QUORUM for profile {}, but got {} instead".format(default, cl)
+            )
+        for changed in ['pre_create_changed_ep', 'pre_connect_changed_ep', 'post_connect_changed_ep']:
+            cl = self.cluster.profile_manager.profiles[changed].consistency_level
+            self.assertEqual(
+                cl, ConsistencyLevel.LOCAL_ONE,
+                "Expecting LOCAL ONE for profile {}, but got {} instead".format(changed, cl)
+            )
+
+    def test_consistency_guardrails(self):
+        self.connect(self.creds)
+        self.session.execute(
+            "CREATE KEYSPACE IF NOT EXISTS test_consistency_guardrails "
+            "with replication={'class': 'SimpleStrategy', 'replication_factor': 1}"
+        )
+        self.session.execute("CREATE TABLE IF NOT EXISTS test_consistency_guardrails.guardrails (id int primary key)")
+        for consistency in DISALLOWED_CONSISTENCIES:
+            statement = SimpleStatement(
+                "INSERT INTO test_consistency_guardrails.guardrails (id) values (1)",
+                consistency_level=consistency
+            )
+            with self.assertRaises(InvalidRequest) as e:
+                self.session.execute(statement)
+            self.assertIn('not allowed for Write Consistency Level', str(e.exception))
+
+        # Sanity check to make sure we can do a normal insert
+        statement = SimpleStatement(
+            "INSERT INTO test_consistency_guardrails.guardrails (id) values (1)",
+            consistency_level=ConsistencyLevel.LOCAL_QUORUM
+        )
+        try:
+            self.session.execute(statement)
+        except InvalidRequest:
+            self.fail("InvalidRequest was incorrectly raised for write query at LOCAL QUORUM!")
+
+    def test_cqlengine_can_connect(self):
+        class TestModel(Model):
+            id = columns.Integer(primary_key=True)
+            val = columns.Text()
+
+        connection.setup(None, "test", cloud={'secure_connect_bundle': self.creds})
+        create_keyspace_simple('test', 1)
+        sync_table(TestModel)
+        TestModel.objects.create(id=42, val='test')
+        self.assertEqual(len(TestModel.objects.all()), 1)
diff --git a/tests/integration/cloud/test_cloud_schema.py b/tests/integration/cloud/test_cloud_schema.py
new file mode 100644
index 0000000000..8dff49508a
--- /dev/null
+++ b/tests/integration/cloud/test_cloud_schema.py
@@ -0,0 +1,118 @@
+# Copyright DataStax, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License +""" +This is mostly copypasta from integration/long/test_schema.py + +TODO: Come up with way to run cloud and local tests without duplication +""" + +import logging +import time + +from cassandra import ConsistencyLevel +from cassandra.cluster import Cluster +from cassandra.query import SimpleStatement + +from tests.integration import execute_until_pass +from tests.integration.cloud import CloudProxyCluster + +log = logging.getLogger(__name__) + + +class CloudSchemaTests(CloudProxyCluster): + def test_recreates(self): + """ + Basic test for repeated schema creation and use, using many different keyspaces + """ + self.connect(self.creds) + session = self.session + + for _ in self.cluster.metadata.all_hosts(): + for keyspace_number in range(5): + keyspace = "ks_{0}".format(keyspace_number) + + if keyspace in self.cluster.metadata.keyspaces.keys(): + drop = "DROP KEYSPACE {0}".format(keyspace) + log.debug(drop) + execute_until_pass(session, drop) + + create = "CREATE KEYSPACE {0} WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 3}}".format( + keyspace) + log.debug(create) + execute_until_pass(session, create) + + create = "CREATE TABLE {0}.cf (k int PRIMARY KEY, i int)".format(keyspace) + log.debug(create) + execute_until_pass(session, create) + + use = "USE {0}".format(keyspace) + log.debug(use) + execute_until_pass(session, use) + + insert = "INSERT INTO cf (k, i) VALUES (0, 0)" + log.debug(insert) + ss = SimpleStatement(insert, consistency_level=ConsistencyLevel.QUORUM) + execute_until_pass(session, ss) + + def test_for_schema_disagreement_attribute(self): + """ + Tests to ensure that schema disagreement is properly surfaced on the response future. + + Creates and destroys keyspaces/tables with various schema agreement timeouts set. + First part runs cql create/drop cmds with schema agreement set in such away were it will be impossible for agreement to occur during timeout. + It then validates that the correct value is set on the result. 
+ Second part ensures that when schema agreement occurs, that the result set reflects that appropriately + + @since 3.1.0 + @jira_ticket PYTHON-458 + @expected_result is_schema_agreed is set appropriately on response thefuture + + @test_category schema + """ + # This should yield a schema disagreement + cloud_config = {'secure_connect_bundle': self.creds} + cluster = Cluster(max_schema_agreement_wait=0.001, protocol_version=4, cloud=cloud_config) + session = cluster.connect(wait_for_all_pools=True) + + rs = session.execute( + "CREATE KEYSPACE test_schema_disagreement WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 3}") + self.check_and_wait_for_agreement(session, rs, False) + rs = session.execute( + SimpleStatement("CREATE TABLE test_schema_disagreement.cf (key int PRIMARY KEY, value int)", + consistency_level=ConsistencyLevel.ALL)) + self.check_and_wait_for_agreement(session, rs, False) + rs = session.execute("DROP KEYSPACE test_schema_disagreement") + self.check_and_wait_for_agreement(session, rs, False) + cluster.shutdown() + + # These should have schema agreement + cluster = Cluster(protocol_version=4, max_schema_agreement_wait=100, cloud=cloud_config) + session = cluster.connect() + rs = session.execute( + "CREATE KEYSPACE test_schema_disagreement WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 3}") + self.check_and_wait_for_agreement(session, rs, True) + rs = session.execute( + SimpleStatement("CREATE TABLE test_schema_disagreement.cf (key int PRIMARY KEY, value int)", + consistency_level=ConsistencyLevel.ALL)) + self.check_and_wait_for_agreement(session, rs, True) + rs = session.execute("DROP KEYSPACE test_schema_disagreement") + self.check_and_wait_for_agreement(session, rs, True) + cluster.shutdown() + + def check_and_wait_for_agreement(self, session, rs, expected): + # Wait for RESULT_KIND_SCHEMA_CHANGE message to arrive + time.sleep(1) + self.assertEqual(rs.response_future.is_schema_agreed, expected) + if not rs.response_future.is_schema_agreed: + session.cluster.control_connection.wait_for_schema_agreement(wait_time=1000) \ No newline at end of file diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py new file mode 100644 index 0000000000..e17ac302c8 --- /dev/null +++ b/tests/integration/conftest.py @@ -0,0 +1,10 @@ +import pytest + +from tests.integration import teardown_package + + +@pytest.fixture(scope='session', autouse=True) +def setup_and_teardown_packages(): + print('setup') + yield + teardown_package() diff --git a/tests/integration/cqlengine/__init__.py b/tests/integration/cqlengine/__init__.py index 2950984f58..5148d6417f 100644 --- a/tests/integration/cqlengine/__init__.py +++ b/tests/integration/cqlengine/__init__.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -14,15 +16,21 @@ import os import warnings +import unittest from cassandra import ConsistencyLevel from cassandra.cqlengine import connection -from cassandra.cqlengine.management import create_keyspace_simple, CQLENG_ALLOW_SCHEMA_MANAGEMENT +from cassandra.cqlengine.management import create_keyspace_simple, drop_keyspace, CQLENG_ALLOW_SCHEMA_MANAGEMENT +import cassandra + +from tests.integration import get_server_versions, use_single_node, PROTOCOL_VERSION, CASSANDRA_IP, ALLOW_BETA_PROTOCOL -from tests.integration import get_server_versions, use_single_node, PROTOCOL_VERSION DEFAULT_KEYSPACE = 'cqlengine_test' +CQL_SKIP_EXECUTE = bool(os.getenv('CQL_SKIP_EXECUTE', False)) + + def setup_package(): warnings.simplefilter('always') # for testing warnings, make sure all are let through os.environ[CQLENG_ALLOW_SCHEMA_MANAGEMENT] = '1' @@ -33,13 +41,76 @@ def setup_package(): create_keyspace_simple(DEFAULT_KEYSPACE, 1) +def teardown_package(): + connection.unregister_connection("default") + def is_prepend_reversed(): # do we have https://issues.apache.org/jira/browse/CASSANDRA-8733 ? ver, _ = get_server_versions() return not (ver >= (2, 0, 13) or ver >= (2, 1, 3)) + def setup_connection(keyspace_name): - connection.setup(['127.0.0.1'], + connection.setup([CASSANDRA_IP], consistency=ConsistencyLevel.ONE, protocol_version=PROTOCOL_VERSION, + allow_beta_protocol_version=ALLOW_BETA_PROTOCOL, default_keyspace=keyspace_name) + + +class StatementCounter(object): + """ + Simple python object used to hold a count of the number of times + the wrapped method has been invoked + """ + def __init__(self, patched_func): + self.func = patched_func + self.counter = 0 + + def wrapped_execute(self, *args, **kwargs): + self.counter += 1 + return self.func(*args, **kwargs) + + def get_counter(self): + return self.counter + + +def execute_count(expected): + """ + A decorator used wrap cassandra.cqlengine.connection.execute. It counts the number of times this method is invoked + then compares it to the number expected. If they don't match it throws an assertion error. + This function can be disabled by running the test harness with the env variable CQL_SKIP_EXECUTE=1 set + """ + def innerCounter(fn): + def wrapped_function(*args, **kwargs): + # Create a counter monkey patch into cassandra.cqlengine.connection.execute + count = StatementCounter(cassandra.cqlengine.connection.execute) + original_function = cassandra.cqlengine.connection.execute + # Monkey patch in our StatementCounter wrapper + cassandra.cqlengine.connection.execute = count.wrapped_execute + # Invoked the underlying unit test + to_return = fn(*args, **kwargs) + # Get the count from our monkey patched counter + count.get_counter() + # DeMonkey Patch our code + cassandra.cqlengine.connection.execute = original_function + # Check to see if we have a pre-existing test case to work from. 
+            if len(args) == 0:
+                test_case = unittest.TestCase("__init__")
+            else:
+                test_case = args[0]
+            # Check to see if the count is what you expect
+            test_case.assertEqual(count.get_counter(), expected, msg="Expected number of cassandra.cqlengine.connection.execute calls ({0}) doesn't match actual number invoked ({1})".format(expected, count.get_counter()))
+            return to_return
+        # Name of the wrapped function must match the original or unittest will error out.
+        wrapped_function.__name__ = fn.__name__
+        wrapped_function.__doc__ = fn.__doc__
+        # Escape hatch
+        if CQL_SKIP_EXECUTE:
+            return fn
+        else:
+            return wrapped_function
+
+    return innerCounter
+
+
diff --git a/tests/integration/cqlengine/advanced/__init__.py b/tests/integration/cqlengine/advanced/__init__.py
new file mode 100644
index 0000000000..588a655d98
--- /dev/null
+++ b/tests/integration/cqlengine/advanced/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/tests/integration/cqlengine/advanced/test_cont_paging.py b/tests/integration/cqlengine/advanced/test_cont_paging.py
new file mode 100644
index 0000000000..82b0818fae
--- /dev/null
+++ b/tests/integration/cqlengine/advanced/test_cont_paging.py
@@ -0,0 +1,169 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ + + +import unittest + +from packaging.version import Version + +from cassandra.cluster import (EXEC_PROFILE_DEFAULT, + ContinuousPagingOptions, ExecutionProfile, + ProtocolVersion) +from cassandra.cqlengine import columns, connection, models +from cassandra.cqlengine.management import drop_table, sync_table +from tests.integration import (DSE_VERSION, greaterthanorequaldse51, + greaterthanorequaldse60, requiredse, TestCluster) + + +class TestMultiKeyModel(models.Model): + partition = columns.Integer(primary_key=True) + cluster = columns.Integer(primary_key=True) + count = columns.Integer(required=False) + text = columns.Text(required=False) + + +def setup_module(): + if DSE_VERSION: + sync_table(TestMultiKeyModel) + for i in range(1000): + TestMultiKeyModel.create(partition=i, cluster=i, count=5, text="text to write") + + +def teardown_module(): + if DSE_VERSION: + drop_table(TestMultiKeyModel) + + +@requiredse +class BasicConcurrentTests(): + required_dse_version = None + protocol_version = None + connections = set() + sane_connections = {"CONTDEFAULT"} + + @classmethod + def setUpClass(cls): + if DSE_VERSION: + cls._create_cluster_with_cp_options("CONTDEFAULT", ContinuousPagingOptions()) + cls._create_cluster_with_cp_options("ONEPAGE", ContinuousPagingOptions(max_pages=1)) + cls._create_cluster_with_cp_options("MANYPAGES", ContinuousPagingOptions(max_pages=10)) + cls._create_cluster_with_cp_options("SLOW", ContinuousPagingOptions(max_pages_per_second=1)) + + @classmethod + def tearDownClass(cls): + if not DSE_VERSION or DSE_VERSION < cls.required_dse_version: + return + + cls.cluster_default.shutdown() + connection.set_default_connection("default") + + @classmethod + def _create_cluster_with_cp_options(cls, name, cp_options): + execution_profiles = {EXEC_PROFILE_DEFAULT: + ExecutionProfile(continuous_paging_options=cp_options)} + cls.cluster_default = TestCluster(protocol_version=cls.protocol_version, + execution_profiles=execution_profiles) + cls.session_default = cls.cluster_default.connect(wait_for_all_pools=True) + connection.register_connection(name, default=True, session=cls.session_default) + cls.connections.add(name) + + def test_continuous_paging_basic(self): + """ + Test to ensure that various continuous paging works with cqlengine + for session + @since DSE 2.4 + @jira_ticket PYTHON-872 + @expected_result various continous paging options should fetch all the results + + @test_category queries + """ + for connection_name in self.sane_connections: + connection.set_default_connection(connection_name) + row = TestMultiKeyModel.get(partition=0, cluster=0) + self.assertEqual(row.partition, 0) + self.assertEqual(row.cluster, 0) + rows = TestMultiKeyModel.objects().allow_filtering() + self.assertEqual(len(rows), 1000) + + def test_fetch_size(self): + """ + Test to ensure that various continuous paging works with different fetch sizes + for session + @since DSE 2.4 + @jira_ticket PYTHON-872 + @expected_result various continous paging options should fetch all the results + + @test_category queries + """ + for connection_name in self.connections: + conn = connection._connections[connection_name] + initial_default = conn.session.default_fetch_size + self.addCleanup( + setattr, + conn.session, + "default_fetch_size", + initial_default + ) + + connection.set_default_connection("ONEPAGE") + for fetch_size in (2, 3, 7, 10, 99, 100, 101, 150): + connection._connections["ONEPAGE"].session.default_fetch_size = fetch_size + rows = TestMultiKeyModel.objects().allow_filtering() + 
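+            # ONEPAGE uses ContinuousPagingOptions(max_pages=1), so only a single page is fetched
+            # and the number of rows returned should equal the fetch size.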
self.assertEqual(fetch_size, len(rows)) + + connection.set_default_connection("MANYPAGES") + for fetch_size in (2, 3, 7, 10, 15): + connection._connections["MANYPAGES"].session.default_fetch_size = fetch_size + rows = TestMultiKeyModel.objects().allow_filtering() + self.assertEqual(fetch_size * 10, len(rows)) + + for connection_name in self.sane_connections: + connection.set_default_connection(connection_name) + for fetch_size in (2, 3, 7, 10, 99, 100, 101, 150): + connection._connections[connection_name].session.default_fetch_size = fetch_size + rows = TestMultiKeyModel.objects().allow_filtering() + self.assertEqual(1000, len(rows)) + + +@requiredse +@greaterthanorequaldse51 +class ContPagingTestsDSEV1(BasicConcurrentTests, unittest.TestCase): + @classmethod + def setUpClass(cls): + BasicConcurrentTests.required_dse_version = Version('5.1') + if not DSE_VERSION or DSE_VERSION < BasicConcurrentTests.required_dse_version: + return + + BasicConcurrentTests.protocol_version = ProtocolVersion.DSE_V1 + BasicConcurrentTests.setUpClass() + +@requiredse +@greaterthanorequaldse60 +class ContPagingTestsDSEV2(BasicConcurrentTests, unittest.TestCase): + @classmethod + def setUpClass(cls): + BasicConcurrentTests.required_dse_version = Version('6.0') + if not DSE_VERSION or DSE_VERSION < BasicConcurrentTests.required_dse_version: + return + BasicConcurrentTests.protocol_version = ProtocolVersion.DSE_V2 + BasicConcurrentTests.setUpClass() + + cls.connections = cls.connections.union({"SMALL_QUEUE", "BIG_QUEUE"}) + cls.sane_connections = cls.sane_connections.union({"SMALL_QUEUE", "BIG_QUEUE"}) + + cls._create_cluster_with_cp_options("SMALL_QUEUE", ContinuousPagingOptions(max_queue_size=2)) + cls._create_cluster_with_cp_options("BIG_QUEUE", ContinuousPagingOptions(max_queue_size=400)) diff --git a/tests/integration/cqlengine/base.py b/tests/integration/cqlengine/base.py index 5536efb3d2..1b99005fc4 100644 --- a/tests/integration/cqlengine/base.py +++ b/tests/integration/cqlengine/base.py @@ -1,25 +1,37 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest import sys from cassandra.cqlengine.connection import get_session +from cassandra.cqlengine.models import Model +from cassandra.cqlengine import columns + +from uuid import uuid4 + +class TestQueryUpdateModel(Model): + partition = columns.UUID(primary_key=True, default=uuid4) + cluster = columns.Integer(primary_key=True) + count = columns.Integer(required=False) + text = columns.Text(required=False, index=True) + text_set = columns.Set(columns.Text, required=False) + text_list = columns.List(columns.Text, required=False) + text_map = columns.Map(columns.Text, columns.Text, required=False) class BaseCassEngTestCase(unittest.TestCase): diff --git a/tests/integration/cqlengine/columns/__init__.py b/tests/integration/cqlengine/columns/__init__.py index 87fc3685e0..588a655d98 100644 --- a/tests/integration/cqlengine/columns/__init__.py +++ b/tests/integration/cqlengine/columns/__init__.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/tests/integration/cqlengine/columns/test_container_columns.py b/tests/integration/cqlengine/columns/test_container_columns.py index b2034bb11b..abdbb6185b 100644 --- a/tests/integration/cqlengine/columns/test_container_columns.py +++ b/tests/integration/cqlengine/columns/test_container_columns.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -15,16 +17,18 @@ from datetime import datetime, timedelta import json import logging -import six import sys import traceback from uuid import uuid4 +from packaging.version import Version -from cassandra import WriteTimeout +from cassandra import WriteTimeout, OperationTimedOut import cassandra.cqlengine.columns as columns from cassandra.cqlengine.functions import get_total_seconds from cassandra.cqlengine.models import Model, ValidationError from cassandra.cqlengine.management import sync_table, drop_table + +from tests.integration import CASSANDRA_IP from tests.integration.cqlengine import is_prepend_reversed from tests.integration.cqlengine.base import BaseCassEngTestCase from tests.integration import greaterthancass20, CASSANDRA_VERSION @@ -45,7 +49,7 @@ class JsonTestColumn(columns.Column): def to_python(self, value): if value is None: return - if isinstance(value, six.string_types): + if isinstance(value, str): return json.loads(value) else: return value @@ -134,8 +138,11 @@ def test_element_count_validation(self): break except WriteTimeout: ex_type, ex, tb = sys.exc_info() - log.warn("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) del tb + except OperationTimedOut: + #This will happen if the host is remote + self.assertFalse(CASSANDRA_IP.startswith("127.0.0.")) self.assertRaises(ValidationError, TestSetModel.create, **{'text_set': set(str(uuid4()) for i in range(65536))}) def test_partial_updates(self): @@ -247,7 +254,7 @@ def test_element_count_validation(self): break except WriteTimeout: ex_type, ex, tb = sys.exc_info() - log.warn("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) del tb self.assertRaises(ValidationError, TestListModel.create, **{'text_list': [str(uuid4()) for _ in range(65536)]}) @@ -417,7 +424,7 @@ def test_element_count_validation(self): break except WriteTimeout: ex_type, ex, tb = sys.exc_info() - log.warn("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) del tb self.assertRaises(ValidationError, TestMapModel.create, **{'text_map': dict((str(uuid4()), i) for i in range(65536))}) @@ -554,7 +561,7 @@ class TestTupleColumn(BaseCassEngTestCase): @classmethod def setUpClass(cls): # Skip annotations don't seem to skip class level teradown and setup methods - if(CASSANDRA_VERSION >= '2.1'): + if CASSANDRA_VERSION >= Version('2.1'): drop_table(TestTupleModel) sync_table(TestTupleModel) @@ -754,7 +761,7 @@ class TestNestedType(BaseCassEngTestCase): @classmethod def setUpClass(cls): # Skip annotations don't seem to skip class level teradown and setup methods - if(CASSANDRA_VERSION >= '2.1'): + if CASSANDRA_VERSION >= Version('2.1'): drop_table(TestNestedModel) sync_table(TestNestedModel) diff --git a/tests/integration/cqlengine/columns/test_counter_column.py b/tests/integration/cqlengine/columns/test_counter_column.py index 11e5c1bd91..e68af62050 100644 --- a/tests/integration/cqlengine/columns/test_counter_column.py 
+++ b/tests/integration/cqlengine/columns/test_counter_column.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -112,3 +114,19 @@ def test_new_instance_defaults_to_zero(self): instance = TestCounterModel() assert instance.counter == 0 + def test_save_after_no_update(self): + expected_value = 15 + instance = TestCounterModel.create() + instance.update(counter=expected_value) + + # read back + instance = TestCounterModel.get(partition=instance.partition) + self.assertEqual(instance.counter, expected_value) + + # save after doing nothing + instance.save() + self.assertEqual(instance.counter, expected_value) + + # make sure there was no increment + instance = TestCounterModel.get(partition=instance.partition) + self.assertEqual(instance.counter, expected_value) diff --git a/tests/integration/cqlengine/columns/test_static_column.py b/tests/integration/cqlengine/columns/test_static_column.py index 543dc84732..8d16ec6227 100644 --- a/tests/integration/cqlengine/columns/test_static_column.py +++ b/tests/integration/cqlengine/columns/test_static_column.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,10 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest from uuid import uuid4 diff --git a/tests/integration/cqlengine/columns/test_validation.py b/tests/integration/cqlengine/columns/test_validation.py index 7b51080bff..48ae74b5ab 100644 --- a/tests/integration/cqlengine/columns/test_validation.py +++ b/tests/integration/cqlengine/columns/test_validation.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,33 +14,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest -from datetime import datetime, timedelta, date, tzinfo +import sys +from datetime import datetime, timedelta, date, tzinfo, time, timezone from decimal import Decimal as D from uuid import uuid4, uuid1 +from packaging.version import Version from cassandra import InvalidRequest -from cassandra.cqlengine.columns import TimeUUID -from cassandra.cqlengine.columns import Text -from cassandra.cqlengine.columns import Integer -from cassandra.cqlengine.columns import BigInt -from cassandra.cqlengine.columns import VarInt -from cassandra.cqlengine.columns import DateTime -from cassandra.cqlengine.columns import Date -from cassandra.cqlengine.columns import UUID -from cassandra.cqlengine.columns import Boolean -from cassandra.cqlengine.columns import Decimal -from cassandra.cqlengine.columns import Inet +from cassandra.cqlengine.columns import (TimeUUID, Ascii, Text, Integer, BigInt, + VarInt, DateTime, Date, UUID, Boolean, + Decimal, Inet, Time, UserDefinedType, + Map, List, Set, Tuple, Double, Duration) from cassandra.cqlengine.connection import execute from cassandra.cqlengine.management import sync_table, drop_table from cassandra.cqlengine.models import Model, ValidationError +from cassandra.cqlengine.usertype import UserType from cassandra import util -from tests.integration import PROTOCOL_VERSION +from tests.integration import PROTOCOL_VERSION, CASSANDRA_VERSION, greaterthanorequalcass30, greaterthanorequalcass3_11 from tests.integration.cqlengine.base import BaseCassEngTestCase @@ -60,7 +55,7 @@ def test_datetime_io(self): now = datetime.now() self.DatetimeTest.objects.create(test_id=0, created_at=now) dt2 = self.DatetimeTest.objects(test_id=0).first() - assert dt2.created_at.timetuple()[:6] == now.timetuple()[:6] + self.assertEqual(dt2.created_at.timetuple()[:6], now.timetuple()[:6]) def test_datetime_tzinfo_io(self): class TZ(tzinfo): @@ -72,21 +67,28 @@ def dst(self, date_time): now = datetime(1982, 1, 1, tzinfo=TZ()) dt = self.DatetimeTest.objects.create(test_id=1, created_at=now) dt2 = self.DatetimeTest.objects(test_id=1).first() - assert dt2.created_at.timetuple()[:6] == (now + timedelta(hours=1)).timetuple()[:6] + self.assertEqual(dt2.created_at.timetuple()[:6], (now + timedelta(hours=1)).timetuple()[:6]) + @greaterthanorequalcass30 def test_datetime_date_support(self): today = date.today() self.DatetimeTest.objects.create(test_id=2, created_at=today) dt2 = self.DatetimeTest.objects(test_id=2).first() - assert dt2.created_at.isoformat() == datetime(today.year, today.month, today.day).isoformat() + self.assertEqual(dt2.created_at.isoformat(), 
datetime(today.year, today.month, today.day).isoformat()) + + result = self.DatetimeTest.objects.all().allow_filtering().filter(test_id=2).first() + self.assertEqual(result.created_at, datetime.combine(today, datetime.min.time())) + + result = self.DatetimeTest.objects.all().allow_filtering().filter(test_id=2, created_at=today).first() + self.assertEqual(result.created_at, datetime.combine(today, datetime.min.time())) def test_datetime_none(self): dt = self.DatetimeTest.objects.create(test_id=3, created_at=None) dt2 = self.DatetimeTest.objects(test_id=3).first() - assert dt2.created_at is None + self.assertIsNone(dt2.created_at) dts = self.DatetimeTest.objects.filter(test_id=3).values_list('created_at') - assert dts[0][0] is None + self.assertIsNone(dts[0][0]) def test_datetime_invalid(self): dt_value= 'INVALID' @@ -97,13 +99,36 @@ def test_datetime_timestamp(self): dt_value = 1454520554 self.DatetimeTest.objects.create(test_id=5, created_at=dt_value) dt2 = self.DatetimeTest.objects(test_id=5).first() - assert dt2.created_at == datetime.utcfromtimestamp(dt_value) + self.assertEqual(dt2.created_at, datetime.fromtimestamp(dt_value, tz=timezone.utc).replace(tzinfo=None)) def test_datetime_large(self): dt_value = datetime(2038, 12, 31, 10, 10, 10, 123000) self.DatetimeTest.objects.create(test_id=6, created_at=dt_value) dt2 = self.DatetimeTest.objects(test_id=6).first() - assert dt2.created_at == dt_value + self.assertEqual(dt2.created_at, dt_value) + + def test_datetime_truncate_microseconds(self): + """ + Test to ensure that truncate microseconds works as expected. + This will be default behavior in the future and we will need to modify the tests to comply + with new behavior + + @since 3.2 + @jira_ticket PYTHON-273 + @expected_result microseconds should be to the nearest thousand when truncate is set. 
+ + @test_category object_mapper + """ + DateTime.truncate_microseconds = True + try: + dt_value = datetime(2024, 12, 31, 10, 10, 10, 923567) + dt_truncated = datetime(2024, 12, 31, 10, 10, 10, 923000) + self.DatetimeTest.objects.create(test_id=6, created_at=dt_value) + dt2 = self.DatetimeTest.objects(test_id=6).first() + self.assertEqual(dt2.created_at,dt_truncated) + finally: + # We need to always return behavior to default + DateTime.truncate_microseconds = False class TestBoolDefault(BaseCassEngTestCase): @@ -122,6 +147,7 @@ def test_default_is_set(self): tmp2 = self.BoolDefaultValueTest.get(test_id=1) self.assertEqual(True, tmp2.stuff) + class TestBoolValidation(BaseCassEngTestCase): class BoolValidationTest(Model): @@ -138,6 +164,7 @@ def test_validation_preserves_none(self): test_obj.validate() self.assertIsNone(test_obj.bool_column) + class TestVarInt(BaseCassEngTestCase): class VarIntTest(Model): @@ -160,49 +187,238 @@ def test_varint_io(self): int2 = self.VarIntTest.objects(test_id=0).first() self.assertEqual(int1.bignum, int2.bignum) + with self.assertRaises(ValidationError): + self.VarIntTest.objects.create(test_id=0, bignum="not_a_number") -class TestDate(BaseCassEngTestCase): - class DateTest(Model): - - test_id = Integer(primary_key=True) - created_at = Date() +class DataType(): @classmethod def setUpClass(cls): - if PROTOCOL_VERSION < 4: + if PROTOCOL_VERSION < 4 or CASSANDRA_VERSION < Version("3.0"): return - sync_table(cls.DateTest) + + class DataTypeTest(Model): + test_id = Integer(primary_key=True) + class_param = cls.db_klass() + + cls.model_class = DataTypeTest + sync_table(cls.model_class) @classmethod def tearDownClass(cls): - if PROTOCOL_VERSION < 4: + if PROTOCOL_VERSION < 4 or CASSANDRA_VERSION < Version("3.0"): return - drop_table(cls.DateTest) + drop_table(cls.model_class) def setUp(self): - if PROTOCOL_VERSION < 4: - raise unittest.SkipTest("Protocol v4 datatypes require native protocol 4+, currently using: {0}".format(PROTOCOL_VERSION)) + if PROTOCOL_VERSION < 4 or CASSANDRA_VERSION < Version("3.0"): + raise unittest.SkipTest("Protocol v4 datatypes " + "require native protocol 4+ and C* version >=3.0, " + "currently using protocol {0} and C* version {1}". 
+ format(PROTOCOL_VERSION, CASSANDRA_VERSION)) - def test_date_io(self): - today = date.today() - self.DateTest.objects.create(test_id=0, created_at=today) - result = self.DateTest.objects(test_id=0).first() - self.assertEqual(result.created_at, util.Date(today)) + def _check_value_is_correct_in_db(self, value): + """ + Check that different ways of reading the value + from the model class give the same expected result + """ + if value is None: + result = self.model_class.objects.all().allow_filtering().filter(test_id=0).first() + self.assertIsNone(result.class_param) + + result = self.model_class.objects(test_id=0).first() + self.assertIsNone(result.class_param) + + else: + if not isinstance(value, self.python_klass): + value_to_compare = self.python_klass(value) + else: + value_to_compare = value + + result = self.model_class.objects(test_id=0).first() + self.assertIsInstance(result.class_param, self.python_klass) + self.assertEqual(result.class_param, value_to_compare) + + result = self.model_class.objects.all().allow_filtering().filter(test_id=0).first() + self.assertIsInstance(result.class_param, self.python_klass) + self.assertEqual(result.class_param, value_to_compare) + + result = self.model_class.objects.all().allow_filtering().filter(test_id=0, class_param=value).first() + self.assertIsInstance(result.class_param, self.python_klass) + self.assertEqual(result.class_param, value_to_compare) + + return result + + def test_param_io(self): + first_value = self.first_value + second_value = self.second_value + third_value = self.third_value + + # Check value is correctly written/read from the DB + self.model_class.objects.create(test_id=0, class_param=first_value) + result = self._check_value_is_correct_in_db(first_value) + result.delete() + + # Check the previous value has been correctly deleted and write a new value + self.model_class.objects.create(test_id=0, class_param=second_value) + result = self._check_value_is_correct_in_db(second_value) + + # Check the value can be correctly updated from the Model class + result.update(class_param=third_value).save() + result = self._check_value_is_correct_in_db(third_value) + + # Check None is correctly written to the DB + result.update(class_param=None).save() + self._check_value_is_correct_in_db(None) + + def test_param_none(self): + """ + Test that None value is correctly written to the db + and then is correctly read + """ + self.model_class.objects.create(test_id=1, class_param=None) + dt2 = self.model_class.objects(test_id=1).first() + self.assertIsNone(dt2.class_param) + + dts = self.model_class.objects(test_id=1).values_list('class_param') + self.assertIsNone(dts[0][0]) + + +class TestDate(DataType, BaseCassEngTestCase): + @classmethod + def setUpClass(cls): + cls.db_klass, cls.python_klass = ( + Date, + util.Date + ) + + cls.first_value, cls.second_value, cls.third_value = ( + datetime.utcnow(), + util.Date(datetime(1, 1, 1)), + datetime(1, 1, 2) + ) + super(TestDate, cls).setUpClass() + + +class TestTime(DataType, BaseCassEngTestCase): + @classmethod + def setUpClass(cls): + cls.db_klass, cls.python_klass = ( + Time, + util.Time + ) + cls.first_value, cls.second_value, cls.third_value = ( + None, + util.Time(time(2, 12, 7, 49)), + time(2, 12, 7, 50) + ) + super(TestTime, cls).setUpClass() + + +class TestDateTime(DataType, BaseCassEngTestCase): + @classmethod + def setUpClass(cls): + cls.db_klass, cls.python_klass = ( + DateTime, + datetime + ) + cls.first_value, cls.second_value, cls.third_value = ( + datetime(2017, 4, 13, 18, 34, 
24, 317000), + datetime(1, 1, 1), + datetime(1, 1, 2) + ) + super(TestDateTime, cls).setUpClass() + + +class TestBoolean(DataType, BaseCassEngTestCase): + @classmethod + def setUpClass(cls): + cls.db_klass, cls.python_klass = ( + Boolean, + bool + ) + cls.first_value, cls.second_value, cls.third_value = ( + None, + False, + True + ) + super(TestBoolean, cls).setUpClass() + +@greaterthanorequalcass3_11 +class TestDuration(DataType, BaseCassEngTestCase): + @classmethod + def setUpClass(cls): + # setUpClass is executed despite the whole class being skipped + if CASSANDRA_VERSION >= Version("3.10"): + cls.db_klass, cls.python_klass = ( + Duration, + util.Duration + ) + cls.first_value, cls.second_value, cls.third_value = ( + util.Duration(0, 0, 0), + util.Duration(1, 2, 3), + util.Duration(0, 0, 0) + ) + super(TestDuration, cls).setUpClass() + + @classmethod + def tearDownClass(cls): + if CASSANDRA_VERSION >= Version("3.10"): + super(TestDuration, cls).tearDownClass() + + +class User(UserType): + # We use Date and Time to ensure to_python + # is called for these columns + age = Integer() + date_param = Date() + map_param = Map(Integer, Time) + list_param = List(Date) + set_param = Set(Date) + tuple_param = Tuple(Date, Decimal, Boolean, VarInt, Double, UUID) - def test_date_io_using_datetime(self): - now = datetime.utcnow() - self.DateTest.objects.create(test_id=0, created_at=now) - result = self.DateTest.objects(test_id=0).first() - self.assertIsInstance(result.created_at, util.Date) - self.assertEqual(result.created_at, util.Date(now)) - def test_date_none(self): - self.DateTest.objects.create(test_id=1, created_at=None) - dt2 = self.DateTest.objects(test_id=1).first() - assert dt2.created_at is None +class UserModel(Model): + test_id = Integer(primary_key=True) + class_param = UserDefinedType(User) - dts = self.DateTest.objects(test_id=1).values_list('created_at') - assert dts[0][0] is None + +class TestUDT(DataType, BaseCassEngTestCase): + @classmethod + def setUpClass(cls): + if PROTOCOL_VERSION < 4 or CASSANDRA_VERSION < Version("3.0"): + return + + cls.db_klass, cls.python_klass = UserDefinedType, User + cls.first_value = User( + age=1, + date_param=datetime.utcnow(), + map_param={1: time(2, 12, 7, 50), 2: util.Time(time(2, 12, 7, 49))}, + list_param=[datetime(1, 1, 2), datetime(1, 1, 3)], + set_param=set((datetime(1, 1, 3), util.Date(datetime(1, 1, 1)))), + tuple_param=(datetime(1, 1, 3), 2, False, 1, 2.324, uuid4()) + ) + + cls.second_value = User( + age=1, + date_param=datetime.utcnow(), + map_param={1: time(2, 12, 7, 50), 2: util.Time(time(2, 12, 7, 49))}, + list_param=[datetime(1, 1, 2), datetime(1, 2, 3)], + set_param=None, + tuple_param=(datetime(1, 1, 2), 2, False, 1, 2.324, uuid4()) + ) + + cls.third_value = User( + age=2, + date_param=None, + map_param={1: time(2, 12, 7, 51), 2: util.Time(time(2, 12, 7, 49))}, + list_param=[datetime(1, 1, 2), datetime(1, 1, 4)], + set_param=set((datetime(1, 1, 3), util.Date(datetime(1, 1, 2)))), + tuple_param=(None, 3, False, None, 2.3214, uuid4()) + ) + + cls.model_class = UserModel + sync_table(cls.model_class) class TestDecimal(BaseCassEngTestCase): @@ -228,6 +444,7 @@ def test_decimal_io(self): dt2 = self.DecimalTest.objects(test_id=0).first() assert dt2.dec_val == D('5') + class TestUUID(BaseCassEngTestCase): class UUIDTest(Model): @@ -261,6 +478,7 @@ def test_uuid_with_upcase(self): t1 = self.UUIDTest.get(test_id=0) assert a_uuid == t1.a_uuid + class TestTimeUUID(BaseCassEngTestCase): class TimeUUIDTest(Model): @@ -285,17 +503,19 @@ def 
test_timeuuid_io(self): assert t1.timeuuid.time == t1.timeuuid.time + class TestInteger(BaseCassEngTestCase): class IntegerTest(Model): test_id = UUID(primary_key=True, default=lambda:uuid4()) - value = Integer(default=0, required=True) + value = Integer(default=0, required=True) def test_default_zero_fields_validate(self): """ Tests that integer columns with a default value of 0 validate """ it = self.IntegerTest() it.validate() + class TestBigInt(BaseCassEngTestCase): class BigIntTest(Model): @@ -307,51 +527,254 @@ def test_default_zero_fields_validate(self): it = self.BigIntTest() it.validate() -class TestText(BaseCassEngTestCase): +class TestAscii(BaseCassEngTestCase): def test_min_length(self): - # not required defaults to 0 - col = Text() - col.validate('') - col.validate('b') + """ Test arbitrary minimal lengths requirements. """ + + Ascii(min_length=0).validate('') + Ascii(min_length=0, required=True).validate('') + + Ascii(min_length=0).validate(None) + Ascii(min_length=0).validate('kevin') + + Ascii(min_length=1).validate('k') + + Ascii(min_length=5).validate('kevin') + Ascii(min_length=5).validate('kevintastic') - # required defaults to 1 with self.assertRaises(ValidationError): - Text(required=True).validate('') + Ascii(min_length=1).validate('') + + with self.assertRaises(ValidationError): + Ascii(min_length=1).validate(None) + + with self.assertRaises(ValidationError): + Ascii(min_length=6).validate('') + + with self.assertRaises(ValidationError): + Ascii(min_length=6).validate(None) + + with self.assertRaises(ValidationError): + Ascii(min_length=6).validate('kevin') + + with self.assertRaises(ValueError): + Ascii(min_length=-1) + + def test_max_length(self): + """ Test arbitrary maximal lengths requirements. """ + Ascii(max_length=0).validate('') + Ascii(max_length=0).validate(None) + + Ascii(max_length=1).validate('') + Ascii(max_length=1).validate(None) + Ascii(max_length=1).validate('b') + + Ascii(max_length=5).validate('') + Ascii(max_length=5).validate(None) + Ascii(max_length=5).validate('b') + Ascii(max_length=5).validate('blake') + + with self.assertRaises(ValidationError): + Ascii(max_length=0).validate('b') + + with self.assertRaises(ValidationError): + Ascii(max_length=5).validate('blaketastic') + + with self.assertRaises(ValueError): + Ascii(max_length=-1) + + def test_length_range(self): + Ascii(min_length=0, max_length=0) + Ascii(min_length=0, max_length=1) + Ascii(min_length=10, max_length=10) + Ascii(min_length=10, max_length=11) + + with self.assertRaises(ValueError): + Ascii(min_length=10, max_length=9) + + with self.assertRaises(ValueError): + Ascii(min_length=1, max_length=0) + + def test_type_checking(self): + Ascii().validate('string') + Ascii().validate(u'unicode') + Ascii().validate(bytearray('bytearray', encoding='ascii')) + + with self.assertRaises(ValidationError): + Ascii().validate(5) + + with self.assertRaises(ValidationError): + Ascii().validate(True) + + Ascii().validate("!#$%&\'()*+,-./") + + with self.assertRaises(ValidationError): + Ascii().validate('Beyonc' + chr(233)) + + if sys.version_info < (3, 1): + with self.assertRaises(ValidationError): + Ascii().validate(u'Beyonc' + unichr(233)) + + def test_unaltering_validation(self): + """ Test the validation step doesn't re-interpret values. 
""" + self.assertEqual(Ascii().validate(''), '') + self.assertEqual(Ascii().validate(None), None) + self.assertEqual(Ascii().validate('yo'), 'yo') + + def test_non_required_validation(self): + """ Tests that validation is ok on none and blank values if required is False. """ + Ascii().validate('') + Ascii().validate(None) + + def test_required_validation(self): + """ Tests that validation raise on none and blank values if value required. """ + Ascii(required=True).validate('k') + + with self.assertRaises(ValidationError): + Ascii(required=True).validate('') + + with self.assertRaises(ValidationError): + Ascii(required=True).validate(None) + + # With min_length set. + Ascii(required=True, min_length=0).validate('k') + Ascii(required=True, min_length=1).validate('k') + + with self.assertRaises(ValidationError): + Ascii(required=True, min_length=2).validate('k') + + # With max_length set. + Ascii(required=True, max_length=1).validate('k') + + with self.assertRaises(ValidationError): + Ascii(required=True, max_length=2).validate('kevin') + + with self.assertRaises(ValueError): + Ascii(required=True, max_length=0) + + +class TestText(BaseCassEngTestCase): + + def test_min_length(self): + """ Test arbitrary minimal lengths requirements. """ - #test arbitrary lengths Text(min_length=0).validate('') + Text(min_length=0, required=True).validate('') + + Text(min_length=0).validate(None) + Text(min_length=0).validate('blake') + + Text(min_length=1).validate('b') + Text(min_length=5).validate('blake') Text(min_length=5).validate('blaketastic') + + with self.assertRaises(ValidationError): + Text(min_length=1).validate('') + + with self.assertRaises(ValidationError): + Text(min_length=1).validate(None) + + with self.assertRaises(ValidationError): + Text(min_length=6).validate('') + + with self.assertRaises(ValidationError): + Text(min_length=6).validate(None) + with self.assertRaises(ValidationError): Text(min_length=6).validate('blake') + with self.assertRaises(ValueError): + Text(min_length=-1) + def test_max_length(self): + """ Test arbitrary maximal lengths requirements. 
""" + Text(max_length=0).validate('') + Text(max_length=0).validate(None) + + Text(max_length=1).validate('') + Text(max_length=1).validate(None) + Text(max_length=1).validate('b') + Text(max_length=5).validate('') + Text(max_length=5).validate(None) + Text(max_length=5).validate('b') Text(max_length=5).validate('blake') + + with self.assertRaises(ValidationError): + Text(max_length=0).validate('b') + with self.assertRaises(ValidationError): Text(max_length=5).validate('blaketastic') + with self.assertRaises(ValueError): + Text(max_length=-1) + + def test_length_range(self): + Text(min_length=0, max_length=0) + Text(min_length=0, max_length=1) + Text(min_length=10, max_length=10) + Text(min_length=10, max_length=11) + + with self.assertRaises(ValueError): + Text(min_length=10, max_length=9) + + with self.assertRaises(ValueError): + Text(min_length=1, max_length=0) + def test_type_checking(self): Text().validate('string') Text().validate(u'unicode') Text().validate(bytearray('bytearray', encoding='ascii')) - with self.assertRaises(ValidationError): - Text(required=True).validate(None) - with self.assertRaises(ValidationError): Text().validate(5) with self.assertRaises(ValidationError): Text().validate(True) + Text().validate("!#$%&\'()*+,-./") + Text().validate('Beyonc' + chr(233)) + if sys.version_info < (3, 1): + Text().validate(u'Beyonc' + unichr(233)) + + def test_unaltering_validation(self): + """ Test the validation step doesn't re-interpret values. """ + self.assertEqual(Text().validate(''), '') + self.assertEqual(Text().validate(None), None) + self.assertEqual(Text().validate('yo'), 'yo') + def test_non_required_validation(self): """ Tests that validation is ok on none and blank values if required is False """ Text().validate('') Text().validate(None) + def test_required_validation(self): + """ Tests that validation raise on none and blank values if value required. """ + Text(required=True).validate('b') + + with self.assertRaises(ValidationError): + Text(required=True).validate('') + + with self.assertRaises(ValidationError): + Text(required=True).validate(None) + + # With min_length set. + Text(required=True, min_length=0).validate('b') + Text(required=True, min_length=1).validate('b') + + with self.assertRaises(ValidationError): + Text(required=True, min_length=2).validate('b') + + # With max_length set. 
+ Text(required=True, max_length=1).validate('b') + + with self.assertRaises(ValidationError): + Text(required=True, max_length=2).validate('blake') + with self.assertRaises(ValueError): + Text(required=True, max_length=0) class TestExtraFieldsRaiseException(BaseCassEngTestCase): @@ -363,6 +786,7 @@ def test_extra_field(self): with self.assertRaises(ValidationError): self.TestModel.create(bacon=5000) + class TestPythonDoesntDieWhenExtraFieldIsInCassandra(BaseCassEngTestCase): class TestModel(Model): @@ -374,7 +798,8 @@ def test_extra_field(self): sync_table(self.TestModel) self.TestModel.create() execute("ALTER TABLE {0} add blah int".format(self.TestModel.column_family_name(include_keyspace=True))) - self.TestModel.objects().all() + self.TestModel.objects.all() + class TestTimeUUIDFromDatetime(BaseCassEngTestCase): def test_conversion_specific_date(self): @@ -386,11 +811,12 @@ def test_conversion_specific_date(self): assert isinstance(uuid, UUID) ts = (uuid.time - 0x01b21dd213814000) / 1e7 # back to a timestamp - new_dt = datetime.utcfromtimestamp(ts) + new_dt = datetime.fromtimestamp(ts, tz=timezone.utc).replace(tzinfo=None) # checks that we created a UUID1 with the proper timestamp assert new_dt == dt + class TestInet(BaseCassEngTestCase): class InetTestModel(Model): @@ -412,4 +838,3 @@ def test_non_address_fails(self): # TODO: presently this only tests that the server blows it up. Is there supposed to be local validation? with self.assertRaises(InvalidRequest): self.InetTestModel.create(address="what is going on here?") - diff --git a/tests/integration/cqlengine/columns/test_value_io.py b/tests/integration/cqlengine/columns/test_value_io.py index 455f80c1e4..faca854fdb 100644 --- a/tests/integration/cqlengine/columns/test_value_io.py +++ b/tests/integration/cqlengine/columns/test_value_io.py @@ -1,25 +1,23 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
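# --- Illustrative sketch (editor's note, not part of the patch) ---------------
# TestTimeUUIDFromDatetime above swaps the deprecated datetime.utcfromtimestamp()
# for a timezone-aware equivalent. A minimal standalone version of that conversion
# is sketched below; the helper name and constant name are assumptions made only
# for this example. 0x01b21dd213814000 is the offset, in 100-nanosecond ticks,
# between the UUID v1 epoch (1582-10-15) and the Unix epoch (1970-01-01).
from datetime import datetime, timezone
from uuid import uuid1

UUID_EPOCH_OFFSET = 0x01b21dd213814000


def timeuuid_to_naive_utc(u):
    """Recover the naive UTC datetime embedded in a version-1 (time-based) UUID."""
    seconds_since_unix_epoch = (u.time - UUID_EPOCH_OFFSET) / 1e7
    # Build an aware datetime, then drop tzinfo to keep the historical naive-UTC value.
    return datetime.fromtimestamp(seconds_since_unix_epoch, tz=timezone.utc).replace(tzinfo=None)


# Example: a UUID created now should round-trip to (roughly) the current UTC time.
print(timeuuid_to_naive_utc(uuid1()))
# -------------------------------------------------------------------------------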
-try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest from datetime import datetime, timedelta, time from decimal import Decimal from uuid import uuid1, uuid4, UUID -import six from cassandra.cqlengine import columns from cassandra.cqlengine.management import sync_table @@ -61,7 +59,6 @@ def setUpClass(cls): # create a table with the given column class IOTestModel(Model): - table_name = cls.column.db_type + "_io_test_model_{0}".format(uuid4().hex[:8]) pkey = cls.column(primary_key=True) data = cls.column() @@ -105,15 +102,15 @@ def test_column_io(self): class TestBlobIO(BaseColumnIOTest): column = columns.Blob - pkey_val = six.b('blake'), uuid4().bytes - data_val = six.b('eggleston'), uuid4().bytes + pkey_val = b'blake', uuid4().bytes + data_val = b'eggleston', uuid4().bytes class TestBlobIO2(BaseColumnIOTest): column = columns.Blob - pkey_val = bytearray(six.b('blake')), uuid4().bytes - data_val = bytearray(six.b('eggleston')), uuid4().bytes + pkey_val = bytearray(b'blake'), uuid4().bytes + data_val = bytearray(b'eggleston'), uuid4().bytes class TestTextIO(BaseColumnIOTest): diff --git a/tests/integration/cqlengine/conftest.py b/tests/integration/cqlengine/conftest.py new file mode 100644 index 0000000000..2dc695828b --- /dev/null +++ b/tests/integration/cqlengine/conftest.py @@ -0,0 +1,12 @@ +import pytest + +from tests.integration import teardown_package as parent_teardown_package +from tests.integration.cqlengine import setup_package, teardown_package + + +@pytest.fixture(scope='session', autouse=True) +def setup_and_teardown_packages(): + setup_package() + yield + teardown_package() + parent_teardown_package() \ No newline at end of file diff --git a/tests/integration/cqlengine/connections/__init__.py b/tests/integration/cqlengine/connections/__init__.py index 1c7af46e71..635f0d9e60 100644 --- a/tests/integration/cqlengine/connections/__init__.py +++ b/tests/integration/cqlengine/connections/__init__.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/tests/integration/cqlengine/connections/test_connection.py b/tests/integration/cqlengine/connections/test_connection.py index 80771d3697..2235fc0c56 100644 --- a/tests/integration/cqlengine/connections/test_connection.py +++ b/tests/integration/cqlengine/connections/test_connection.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,22 +14,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest +from cassandra import ConsistencyLevel from cassandra.cqlengine.models import Model -from cassandra.cqlengine import columns, connection +from cassandra.cqlengine import columns, connection, models from cassandra.cqlengine.management import sync_table -from cassandra.cluster import Cluster +from cassandra.cluster import ExecutionProfile, _clusters_for_shutdown, _ConfigMode, EXEC_PROFILE_DEFAULT +from cassandra.policies import RoundRobinPolicy from cassandra.query import dict_factory -from tests.integration import PROTOCOL_VERSION, execute_with_long_wait_retry +from tests.integration import CASSANDRA_IP, PROTOCOL_VERSION, execute_with_long_wait_retry, local, TestCluster from tests.integration.cqlengine.base import BaseCassEngTestCase from tests.integration.cqlengine import DEFAULT_KEYSPACE, setup_connection -from cassandra.cqlengine import models class TestConnectModel(Model): @@ -36,15 +36,46 @@ class TestConnectModel(Model): keyspace = columns.Text() -class ConnectionTest(BaseCassEngTestCase): +class ConnectionTest(unittest.TestCase): + def tearDown(self): + connection.unregister_connection("default") + + @local + def test_connection_setup_with_setup(self): + connection.setup(hosts=None, default_keyspace=None) + self.assertIsNotNone(connection.get_connection("default").cluster.metadata.get_host("127.0.0.1")) + + @local + def test_connection_setup_with_default(self): + connection.default() + self.assertIsNotNone(connection.get_connection("default").cluster.metadata.get_host("127.0.0.1")) + + def test_only_one_connection_is_created(self): + """ + Test to ensure that only one new connection is created by + connection.register_connection + + @since 3.12 + @jira_ticket PYTHON-814 + @expected_result Only one connection is created + + @test_category object_mapper + """ + number_of_clusters_before = len(_clusters_for_shutdown) + connection.default() + number_of_clusters_after = len(_clusters_for_shutdown) + self.assertEqual(number_of_clusters_after - number_of_clusters_before, 1) + + +class SeveralConnectionsTest(BaseCassEngTestCase): @classmethod def setUpClass(cls): - cls.original_cluster = connection.cluster + connection.unregister_connection('default') cls.keyspace1 = 'ctest1' cls.keyspace2 = 'ctest2' - super(ConnectionTest, cls).setUpClass() - cls.setup_cluster = Cluster(protocol_version=PROTOCOL_VERSION) + super(SeveralConnectionsTest, cls).setUpClass() + cls.setup_cluster = TestCluster() cls.setup_session = cls.setup_cluster.connect() ddl = "CREATE KEYSPACE {0} WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': '{1}'}}".format(cls.keyspace1, 1) execute_with_long_wait_retry(cls.setup_session, ddl) @@ -56,13 +87,12 @@ def tearDownClass(cls): 
execute_with_long_wait_retry(cls.setup_session, "DROP KEYSPACE {0}".format(cls.keyspace1)) execute_with_long_wait_retry(cls.setup_session, "DROP KEYSPACE {0}".format(cls.keyspace2)) models.DEFAULT_KEYSPACE = DEFAULT_KEYSPACE - cls.original_cluster.shutdown() cls.setup_cluster.shutdown() setup_connection(DEFAULT_KEYSPACE) models.DEFAULT_KEYSPACE def setUp(self): - self.c = Cluster(protocol_version=PROTOCOL_VERSION) + self.c = TestCluster() self.session1 = self.c.connect(keyspace=self.keyspace1) self.session1.row_factory = dict_factory self.session2 = self.c.connect(keyspace=self.keyspace2) @@ -74,7 +104,7 @@ def tearDown(self): def test_connection_session_switch(self): """ Test to ensure that when the default keyspace is changed in a session and that session, - is set in the connection class, that the new defaul keyspace is honored. + is set in the connection class, that the new default keyspace is honored. @since 3.1 @jira_ticket PYTHON-486 @@ -96,3 +126,72 @@ def test_connection_session_switch(self): self.assertEqual(1, TestConnectModel.objects.count()) self.assertEqual(TestConnectModel.objects.first(), TCM2) + +class ConnectionModel(Model): + key = columns.Integer(primary_key=True) + some_data = columns.Text() + + +class ConnectionInitTest(unittest.TestCase): + def test_default_connection_uses_legacy(self): + connection.default() + conn = connection.get_connection() + self.assertEqual(conn.cluster._config_mode, _ConfigMode.LEGACY) + + def test_connection_with_legacy_settings(self): + connection.setup( + hosts=[CASSANDRA_IP], + default_keyspace=DEFAULT_KEYSPACE, + consistency=ConsistencyLevel.LOCAL_ONE + ) + conn = connection.get_connection() + self.assertEqual(conn.cluster._config_mode, _ConfigMode.LEGACY) + + def test_connection_from_session_with_execution_profile(self): + cluster = TestCluster(execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=dict_factory)}) + session = cluster.connect() + connection.default() + connection.set_session(session) + conn = connection.get_connection() + self.assertEqual(conn.cluster._config_mode, _ConfigMode.PROFILES) + + def test_connection_from_session_with_legacy_settings(self): + cluster = TestCluster(load_balancing_policy=RoundRobinPolicy()) + session = cluster.connect() + session.row_factory = dict_factory + connection.set_session(session) + conn = connection.get_connection() + self.assertEqual(conn.cluster._config_mode, _ConfigMode.LEGACY) + + def test_uncommitted_session_uses_legacy(self): + cluster = TestCluster() + session = cluster.connect() + session.row_factory = dict_factory + connection.set_session(session) + conn = connection.get_connection() + self.assertEqual(conn.cluster._config_mode, _ConfigMode.LEGACY) + + def test_legacy_insert_query(self): + connection.setup( + hosts=[CASSANDRA_IP], + default_keyspace=DEFAULT_KEYSPACE, + consistency=ConsistencyLevel.LOCAL_ONE + ) + self.assertEqual(connection.get_connection().cluster._config_mode, _ConfigMode.LEGACY) + + sync_table(ConnectionModel) + ConnectionModel.objects.create(key=0, some_data='text0') + ConnectionModel.objects.create(key=1, some_data='text1') + self.assertEqual(ConnectionModel.objects(key=0)[0].some_data, 'text0') + + def test_execution_profile_insert_query(self): + cluster = TestCluster(execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=dict_factory)}) + session = cluster.connect() + connection.default() + connection.set_session(session) + self.assertEqual(connection.get_connection().cluster._config_mode, _ConfigMode.PROFILES) + + 
sync_table(ConnectionModel) + ConnectionModel.objects.create(key=0, some_data='text0') + ConnectionModel.objects.create(key=1, some_data='text1') + self.assertEqual(ConnectionModel.objects(key=0)[0].some_data, 'text0') diff --git a/tests/integration/cqlengine/management/__init__.py b/tests/integration/cqlengine/management/__init__.py index 87fc3685e0..588a655d98 100644 --- a/tests/integration/cqlengine/management/__init__.py +++ b/tests/integration/cqlengine/management/__init__.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/tests/integration/cqlengine/management/test_compaction_settings.py b/tests/integration/cqlengine/management/test_compaction_settings.py index 4fe349e069..fbb5870ebb 100644 --- a/tests/integration/cqlengine/management/test_compaction_settings.py +++ b/tests/integration/cqlengine/management/test_compaction_settings.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +15,7 @@ # limitations under the License. 
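# --- Illustrative sketch (editor's note, not part of the patch) ---------------
# The ConnectionInitTest cases above distinguish two ways of wiring cqlengine to a
# cluster: legacy keyword settings passed to connection.setup(), versus a Session
# built with an ExecutionProfile and handed over via connection.set_session().
# '127.0.0.1' and 'example_ks' below are placeholders, not values from the patch.
from cassandra import ConsistencyLevel
from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT
from cassandra.cqlengine import connection
from cassandra.query import dict_factory


def connect_legacy():
    # Cluster-wide keyword settings; the driver stays in "legacy" config mode.
    connection.setup(hosts=['127.0.0.1'],
                     default_keyspace='example_ks',
                     consistency=ConsistencyLevel.LOCAL_ONE)


def connect_with_profiles():
    # Settings carried by an execution profile; the driver uses "profiles" config mode.
    cluster = Cluster(execution_profiles={
        EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=dict_factory)})
    session = cluster.connect()
    connection.set_session(session)
# -------------------------------------------------------------------------------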
import copy -from mock import patch -import six +from unittest.mock import patch from cassandra.cqlengine import columns from cassandra.cqlengine.management import drop_table, sync_table, _get_table_metadata, _update_options @@ -83,7 +84,7 @@ def test_alter_actually_alters(self): table_meta = _get_table_metadata(tmp) - self.assertRegexpMatches(table_meta.export_as_string(), '.*SizeTieredCompactionStrategy.*') + self.assertRegex(table_meta.export_as_string(), '.*SizeTieredCompactionStrategy.*') def test_alter_options(self): @@ -97,11 +98,11 @@ class AlterTable(Model): drop_table(AlterTable) sync_table(AlterTable) table_meta = _get_table_metadata(AlterTable) - self.assertRegexpMatches(table_meta.export_as_string(), ".*'sstable_size_in_mb': '64'.*") + self.assertRegex(table_meta.export_as_string(), ".*'sstable_size_in_mb': '64'.*") AlterTable.__options__['compaction']['sstable_size_in_mb'] = '128' sync_table(AlterTable) table_meta = _get_table_metadata(AlterTable) - self.assertRegexpMatches(table_meta.export_as_string(), ".*'sstable_size_in_mb': '128'.*") + self.assertRegex(table_meta.export_as_string(), ".*'sstable_size_in_mb': '128'.*") class OptionsTest(BaseCassEngTestCase): @@ -110,7 +111,7 @@ def _verify_options(self, table_meta, expected_options): cql = table_meta.export_as_string() for name, value in expected_options.items(): - if isinstance(value, six.string_types): + if isinstance(value, str): self.assertIn("%s = '%s'" % (name, value), cql) else: start = cql.find("%s = {" % (name,)) diff --git a/tests/integration/cqlengine/management/test_management.py b/tests/integration/cqlengine/management/test_management.py index e4b35e2136..c424c187ce 100644 --- a/tests/integration/cqlengine/management/test_management.py +++ b/tests/integration/cqlengine/management/test_management.py @@ -1,23 +1,23 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
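# --- Illustrative sketch (editor's note, not part of the patch) ---------------
# The compaction-settings tests above drive ALTER TABLE through model __options__.
# The same pattern in isolation is sketched below; the model name and option values
# are examples, and a registered cqlengine connection with a default keyspace is
# assumed to exist already.
from cassandra.cqlengine import columns
from cassandra.cqlengine.management import sync_table
from cassandra.cqlengine.models import Model


class CompactionExample(Model):
    __options__ = {'compaction': {'class': 'LeveledCompactionStrategy',
                                  'sstable_size_in_mb': '64'}}
    user_id = columns.UUID(primary_key=True)


sync_table(CompactionExample)   # CREATE TABLE ... WITH compaction = {... '64' ...}

# Changing an option and re-syncing issues an ALTER TABLE with the new value.
CompactionExample.__options__['compaction']['sstable_size_in_mb'] = '128'
sync_table(CompactionExample)
# -------------------------------------------------------------------------------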
-try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest -import mock +from unittest import mock import logging +from packaging.version import Version from cassandra.cqlengine.connection import get_session, get_cluster from cassandra.cqlengine import CQLEngineException from cassandra.cqlengine import management @@ -25,13 +25,16 @@ from cassandra.cqlengine.models import Model from cassandra.cqlengine import columns -from tests.integration import PROTOCOL_VERSION, greaterthancass20, MockLoggingHandler, CASSANDRA_VERSION +from tests.integration import DSE_VERSION, PROTOCOL_VERSION, greaterthancass20, MockLoggingHandler, CASSANDRA_VERSION from tests.integration.cqlengine.base import BaseCassEngTestCase from tests.integration.cqlengine.query.test_queryset import TestModel from cassandra.cqlengine.usertype import UserType from tests.integration.cqlengine import DEFAULT_KEYSPACE +INCLUDE_REPAIR = not CASSANDRA_VERSION >= Version('4-a') # This should cover DSE 6.0+ + + class KeyspaceManagementTest(BaseCassEngTestCase): def test_create_drop_succeeeds(self): cluster = get_cluster() @@ -78,9 +81,45 @@ class CapitalizedKeyModel(Model): class PrimaryKeysOnlyModel(Model): + __table_name__ = "primary_keys_only" + __options__ = {'compaction': {'class': 'LeveledCompactionStrategy'}} + + first_key = columns.Integer(primary_key=True) + second_key = columns.Integer(primary_key=True) + + +class PrimaryKeysModelChanged(Model): + + __table_name__ = "primary_keys_only" + __options__ = {'compaction': {'class': 'LeveledCompactionStrategy'}} + + new_first_key = columns.Integer(primary_key=True) + second_key = columns.Integer(primary_key=True) + + +class PrimaryKeysModelTypeChanged(Model): + + __table_name__ = "primary_keys_only" + __options__ = {'compaction': {'class': 'LeveledCompactionStrategy'}} + + first_key = columns.Float(primary_key=True) + second_key = columns.Integer(primary_key=True) + + +class PrimaryKeysRemovedPk(Model): + + __table_name__ = "primary_keys_only" __options__ = {'compaction': {'class': 'LeveledCompactionStrategy'}} - first_ey = columns.Integer(primary_key=True) + second_key = columns.Integer(primary_key=True) + + +class PrimaryKeysAddedClusteringKey(Model): + + __table_name__ = "primary_keys_only" + __options__ = {'compaction': {'class': 'LeveledCompactionStrategy'}} + + new_first_key = columns.Float(primary_key=True) second_key = columns.Integer(primary_key=True) @@ -169,9 +208,13 @@ class ModelWithTableProperties(Model): __options__ = {'bloom_filter_fp_chance': '0.76328', 'comment': 'TxfguvBdzwROQALmQBOziRMbkqVGFjqcJfVhwGR', - 'gc_grace_seconds': '2063', - 'read_repair_chance': '0.17985', - 'dclocal_read_repair_chance': '0.50811'} + 'gc_grace_seconds': '2063'} + + if INCLUDE_REPAIR: + __options__.update( + {'read_repair_chance': '0.17985', + 'dclocal_read_repair_chance': '0.50811'} + ) key = columns.UUID(primary_key=True) @@ -187,13 +230,15 @@ def test_set_table_properties(self): expected = {'bloom_filter_fp_chance': 0.76328, 'comment': 'TxfguvBdzwROQALmQBOziRMbkqVGFjqcJfVhwGR', 'gc_grace_seconds': 2063, - 'read_repair_chance': 0.17985, # For some reason 'dclocal_read_repair_chance' in CQL is called # just 'local_read_repair_chance' in the schema table. 
# Source: https://issues.apache.org/jira/browse/CASSANDRA-6717 # TODO: due to a bug in the native driver i'm not seeing the local read repair chance show up # 'local_read_repair_chance': 0.50811, } + if INCLUDE_REPAIR: + expected.update({'read_repair_chance': 0.17985}) + options = management._get_table_metadata(ModelWithTableProperties).options self.assertEqual(dict([(k, options.get(k)) for k in expected.keys()]), expected) @@ -203,21 +248,22 @@ def test_table_property_update(self): ModelWithTableProperties.__options__['comment'] = 'xirAkRWZVVvsmzRvXamiEcQkshkUIDINVJZgLYSdnGHweiBrAiJdLJkVohdRy' ModelWithTableProperties.__options__['gc_grace_seconds'] = 96362 - ModelWithTableProperties.__options__['read_repair_chance'] = 0.2989 - ModelWithTableProperties.__options__['dclocal_read_repair_chance'] = 0.12732 + if INCLUDE_REPAIR: + ModelWithTableProperties.__options__['read_repair_chance'] = 0.2989 + ModelWithTableProperties.__options__['dclocal_read_repair_chance'] = 0.12732 sync_table(ModelWithTableProperties) table_options = management._get_table_metadata(ModelWithTableProperties).options - self.assertDictContainsSubset(ModelWithTableProperties.__options__, table_options) + self.assertLessEqual(ModelWithTableProperties.__options__.items(), table_options.items()) def test_bogus_option_update(self): sync_table(ModelWithTableProperties) option = 'no way will this ever be an option' try: ModelWithTableProperties.__options__[option] = 'what was I thinking?' - self.assertRaisesRegexp(KeyError, "Invalid table option.*%s.*" % option, sync_table, ModelWithTableProperties) + self.assertRaisesRegex(KeyError, "Invalid table option.*%s.*" % option, sync_table, ModelWithTableProperties) finally: ModelWithTableProperties.__options__.pop(option, None) @@ -242,6 +288,21 @@ def test_sync_table_works_with_primary_keys_only_tables(self): table_meta = management._get_table_metadata(PrimaryKeysOnlyModel) self.assertIn('SizeTieredCompactionStrategy', table_meta.as_cql_query()) + def test_primary_key_validation(self): + """ + Test to ensure that changes to primary keys throw CQLEngineExceptions + + @since 3.2 + @jira_ticket PYTHON-532 + @expected_result Attempts to modify primary keys throw an exception + + @test_category object_mapper + """ + sync_table(PrimaryKeysOnlyModel) + self.assertRaises(CQLEngineException, sync_table, PrimaryKeysModelChanged) + self.assertRaises(CQLEngineException, sync_table, PrimaryKeysAddedClusteringKey) + self.assertRaises(CQLEngineException, sync_table, PrimaryKeysRemovedPk) + class IndexModel(Model): @@ -307,7 +368,7 @@ def test_sync_warnings(self): sync_table(BaseInconsistent) sync_table(ChangedInconsistent) self.assertTrue('differing from the model type' in mock_handler.messages.get('warning')[0]) - if CASSANDRA_VERSION >= '2.1': + if CASSANDRA_VERSION >= Version('2.1'): sync_type(DEFAULT_KEYSPACE, BaseInconsistentType) mock_handler.reset() sync_type(DEFAULT_KEYSPACE, ChangedInconsistentType) @@ -315,6 +376,14 @@ def test_sync_warnings(self): logger.removeHandler(mock_handler) +class TestIndexSetModel(Model): + partition = columns.UUID(primary_key=True) + int_set = columns.Set(columns.Integer, index=True) + int_list = columns.List(columns.Integer, index=True) + text_map = columns.Map(columns.Text, columns.DateTime, index=True) + mixed_tuple = columns.Tuple(columns.Text, columns.Integer, columns.Text, index=True) + + class IndexTests(BaseCassEngTestCase): def setUp(self): @@ -361,6 +430,24 @@ def test_sync_index_case_sensitive(self): table_meta = 
management._get_table_metadata(IndexCaseSensitiveModel) self.assertIsNotNone(management._get_index_name_by_column(table_meta, 'second_key')) + @greaterthancass20 + def test_sync_indexed_set(self): + """ + Tests that models that have container types with indices can be synced. + + @since 3.2 + @jira_ticket PYTHON-533 + @expected_result table_sync should complete without a server error. + + @test_category object_mapper + """ + sync_table(TestIndexSetModel) + table_meta = management._get_table_metadata(TestIndexSetModel) + self.assertIsNotNone(management._get_index_name_by_column(table_meta, 'int_set')) + self.assertIsNotNone(management._get_index_name_by_column(table_meta, 'int_list')) + self.assertIsNotNone(management._get_index_name_by_column(table_meta, 'text_map')) + self.assertIsNotNone(management._get_index_name_by_column(table_meta, 'mixed_tuple')) + class NonModelFailureTest(BaseCassEngTestCase): class FakeModel(object): diff --git a/tests/integration/cqlengine/model/__init__.py b/tests/integration/cqlengine/model/__init__.py index 87fc3685e0..588a655d98 100644 --- a/tests/integration/cqlengine/model/__init__.py +++ b/tests/integration/cqlengine/model/__init__.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/tests/integration/cqlengine/model/test_class_construction.py b/tests/integration/cqlengine/model/test_class_construction.py index 8147e41079..00051d9248 100644 --- a/tests/integration/cqlengine/model/test_class_construction.py +++ b/tests/integration/cqlengine/model/test_class_construction.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -47,9 +49,30 @@ class TestModel(Model): inst = TestModel() self.assertHasAttr(inst, 'id') self.assertHasAttr(inst, 'text') - self.assertIsNone(inst.id) + self.assertIsNotNone(inst.id) self.assertIsNone(inst.text) + def test_values_on_instantiation(self): + """ + Tests defaults and user-provided values on instantiation. + """ + + class TestPerson(Model): + first_name = columns.Text(primary_key=True, default='kevin') + last_name = columns.Text(default='deldycke') + + # Check that defaults are available at instantiation. + inst1 = TestPerson() + self.assertHasAttr(inst1, 'first_name') + self.assertHasAttr(inst1, 'last_name') + self.assertEqual(inst1.first_name, 'kevin') + self.assertEqual(inst1.last_name, 'deldycke') + + # Check that values on instantiation overrides defaults. + inst2 = TestPerson(first_name='bob', last_name='joe') + self.assertEqual(inst2.first_name, 'bob') + self.assertEqual(inst2.last_name, 'joe') + def test_db_map(self): """ Tests that the db_map is properly defined @@ -70,7 +93,7 @@ def test_attempting_to_make_duplicate_column_names_fails(self): Tests that trying to create conflicting db column names will fail """ - with self.assertRaisesRegexp(ModelException, r".*more than once$"): + with self.assertRaisesRegex(ModelException, r".*more than once$"): class BadNames(Model): words = columns.Text(primary_key=True) content = columns.Text(db_field='words') diff --git a/tests/integration/cqlengine/model/test_equality_operations.py b/tests/integration/cqlengine/model/test_equality_operations.py index 9391ce6a79..89045d7714 100644 --- a/tests/integration/cqlengine/model/test_equality_operations.py +++ b/tests/integration/cqlengine/model/test_equality_operations.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/tests/integration/cqlengine/model/test_model.py b/tests/integration/cqlengine/model/test_model.py index e46698ff75..c2c4906441 100644 --- a/tests/integration/cqlengine/model/test_model.py +++ b/tests/integration/cqlengine/model/test_model.py @@ -1,28 +1,29 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest -from mock import patch +from unittest.mock import patch from cassandra.cqlengine import columns, CQLEngineException from cassandra.cqlengine.management import sync_table, drop_table, create_keyspace_simple, drop_keyspace from cassandra.cqlengine import models from cassandra.cqlengine.models import Model, ModelDefinitionException - +from uuid import uuid1 +from tests.integration import pypy +from tests.integration.cqlengine.base import TestQueryUpdateModel class TestModel(unittest.TestCase): """ Tests the non-io functionality of models """ @@ -173,3 +174,97 @@ class IllegalFilterColumnModel(Model): filter = columns.Text() +class ModelOverWriteTest(unittest.TestCase): + + def test_model_over_write(self): + """ + Test to ensure overwriting of primary keys in model inheritance is allowed + + This is currently only an issue in PyPy. When PYTHON-504 is introduced this should + be updated error out and warn the user + + @since 3.6.0 + @jira_ticket PYTHON-576 + @expected_result primary keys can be overwritten via inheritance + + @test_category object_mapper + """ + class TimeModelBase(Model): + uuid = columns.TimeUUID(primary_key=True) + + class DerivedTimeModel(TimeModelBase): + __table_name__ = 'derived_time' + uuid = columns.TimeUUID(primary_key=True, partition_key=True) + value = columns.Text(required=False) + + # In case the table already exists in keyspace + drop_table(DerivedTimeModel) + + sync_table(DerivedTimeModel) + uuid_value = uuid1() + uuid_value2 = uuid1() + DerivedTimeModel.create(uuid=uuid_value, value="first") + DerivedTimeModel.create(uuid=uuid_value2, value="second") + DerivedTimeModel.objects.filter(uuid=uuid_value) + + +class TestColumnComparison(unittest.TestCase): + def test_comparison(self): + l = [TestQueryUpdateModel.partition.column, + TestQueryUpdateModel.cluster.column, + TestQueryUpdateModel.count.column, + TestQueryUpdateModel.text.column, + TestQueryUpdateModel.text_set.column, + TestQueryUpdateModel.text_list.column, + TestQueryUpdateModel.text_map.column] + + self.assertEqual(l, sorted(l)) + self.assertNotEqual(TestQueryUpdateModel.partition.column, TestQueryUpdateModel.cluster.column) + self.assertLessEqual(TestQueryUpdateModel.partition.column, TestQueryUpdateModel.cluster.column) + self.assertGreater(TestQueryUpdateModel.cluster.column, TestQueryUpdateModel.partition.column) + self.assertGreaterEqual(TestQueryUpdateModel.cluster.column, TestQueryUpdateModel.partition.column) + + +class TestDeprecationWarning(unittest.TestCase): + def test_deprecation_warnings(self): + """ + Test to some deprecation warning have been added. 
It tests warnings for + negative index, negative index slicing and table sensitive removal + + This test should be removed in 4.0, that's why the imports are in + this test, so it's easier to remove + + @since 3.13 + @jira_ticket PYTHON-877 + @expected_result the deprecation warnings are emitted + + @test_category logs + """ + import warnings + + class SensitiveModel(Model): + __table_name__ = 'SensitiveModel' + __table_name_case_sensitive__ = True + k = columns.Integer(primary_key=True) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + sync_table(SensitiveModel) + self.addCleanup(drop_table, SensitiveModel) + + SensitiveModel.create(k=0) + + rows = SensitiveModel.objects().all().allow_filtering() + rows[-1] + rows[-1:] + + # Asyncio complains loudly about old syntax on python 3.7+, so get rid of all of those + relevant_warnings = [warn for warn in w if "with (yield from lock)" not in str(warn.message)] + + self.assertEqual(len(relevant_warnings), 4) + self.assertIn("__table_name_case_sensitive__ will be removed in 4.0.", str(relevant_warnings[0].message)) + self.assertIn("__table_name_case_sensitive__ will be removed in 4.0.", str(relevant_warnings[1].message)) + self.assertIn("ModelQuerySet indexing with negative indices support will be removed in 4.0.", + str(relevant_warnings[2].message)) + self.assertIn("ModelQuerySet slicing with negative indices support will be removed in 4.0.", + str(relevant_warnings[3].message)) diff --git a/tests/integration/cqlengine/model/test_model_io.py b/tests/integration/cqlengine/model/test_model_io.py index b05c8b7cb4..9cff0af6a6 100644 --- a/tests/integration/cqlengine/model/test_model_io.py +++ b/tests/integration/cqlengine/model/test_model_io.py @@ -1,36 +1,40 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
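# --- Illustrative sketch (editor's note, not part of the patch) ---------------
# TestDeprecationWarning above relies on the standard-library warnings-capture
# pattern shown here in isolation: record every warning inside the context manager,
# then assert on the captured messages.
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")   # record every warning, even repeats
    warnings.warn("ModelQuerySet indexing with negative indices support "
                  "will be removed in 4.0.", DeprecationWarning)

assert any("will be removed in 4.0." in str(w.message) for w in caught)
# -------------------------------------------------------------------------------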
-try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest from uuid import uuid4, UUID import random -from datetime import datetime, date, time +from datetime import datetime, date, time, timezone from decimal import Decimal from operator import itemgetter +import cassandra from cassandra.cqlengine import columns from cassandra.cqlengine import CQLEngineException from cassandra.cqlengine.management import sync_table from cassandra.cqlengine.management import drop_table from cassandra.cqlengine.models import Model -from cassandra.util import Date, Time +from cassandra.query import SimpleStatement +from cassandra.util import Date, Time, Duration +from cassandra.cqlengine.statements import SelectStatement, DeleteStatement, WhereClause +from cassandra.cqlengine.operators import EqualsOperator -from tests.integration import PROTOCOL_VERSION +from tests.integration import PROTOCOL_VERSION, greaterthanorequalcass3_10 from tests.integration.cqlengine.base import BaseCassEngTestCase +from tests.integration.cqlengine import DEFAULT_KEYSPACE class TestModel(Model): @@ -65,7 +69,7 @@ def tearDownClass(cls): def test_model_save_and_load(self): """ - Tests that models can be saved and retrieved + Tests that models can be saved and retrieved, using the create method. """ tm = TestModel.create(count=8, text='123456789') self.assertIsInstance(tm, TestModel) @@ -76,6 +80,22 @@ def test_model_save_and_load(self): for cname in tm._columns.keys(): self.assertEqual(getattr(tm, cname), getattr(tm2, cname)) + def test_model_instantiation_save_and_load(self): + """ + Tests that models can be saved and retrieved, this time using the + natural model instantiation. + """ + tm = TestModel(count=8, text='123456789') + # Tests that values are available on instantiation. + self.assertIsNotNone(tm['id']) + self.assertEqual(tm.count, 8) + self.assertEqual(tm.text, '123456789') + tm.save() + tm2 = TestModel.objects(id=tm.id).first() + + for cname in tm._columns.keys(): + self.assertEqual(getattr(tm, cname), getattr(tm2, cname)) + def test_model_read_as_dict(self): """ Tests that columns of an instance can be read as a dict. 
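# --- Illustrative sketch (editor's note, not part of the patch) ---------------
# The model-IO tests above compare two equivalent write paths: Model.create() and
# plain instantiation followed by save(), with column defaults applied at
# instantiation time. A minimal self-contained version, assuming a cqlengine
# connection and default keyspace are already registered (the model below is a
# stand-in, not the TestModel from the patch):
from uuid import uuid4

from cassandra.cqlengine import columns
from cassandra.cqlengine.management import sync_table
from cassandra.cqlengine.models import Model


class ExampleModel(Model):
    id = columns.UUID(primary_key=True, default=uuid4)
    count = columns.Integer()
    text = columns.Text(required=False)


sync_table(ExampleModel)

row1 = ExampleModel.create(count=8, text='123456789')   # manager-style insert
row2 = ExampleModel(count=9, text='987654321')          # defaults (e.g. id) applied here
row2.save()                                             # row written on save()

assert ExampleModel.objects(id=row1.id).first().count == 8
assert ExampleModel.objects(id=row2.id).first().text == '987654321'
# -------------------------------------------------------------------------------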
@@ -142,6 +162,7 @@ def test_a_sensical_error_is_raised_if_you_try_to_create_a_table_twice(self): sync_table(TestModel) sync_table(TestModel) + @greaterthanorequalcass3_10 def test_can_insert_model_with_all_column_types(self): """ Test for inserting all column types into a Model @@ -174,22 +195,24 @@ class AllDatatypesModel(Model): l = columns.TimeUUID() m = columns.UUID() n = columns.VarInt() + o = columns.Duration() sync_table(AllDatatypesModel) - input = ['ascii', 2 ** 63 - 1, bytearray(b'hello world'), True, datetime.utcfromtimestamp(872835240), + input = ['ascii', 2 ** 63 - 1, bytearray(b'hello world'), True, datetime.fromtimestamp(872835240, tz=timezone.utc).replace(tzinfo=None), Decimal('12.3E+7'), 2.39, 3.4028234663852886e+38, '123.123.123.123', 2147483647, 'text', UUID('FE2B4360-28C6-11E2-81C1-0800200C9A66'), UUID('067e6162-3b6f-4ae2-a171-2470b63dff00'), int(str(2147483647) + '000')] AllDatatypesModel.create(id=0, a='ascii', b=2 ** 63 - 1, c=bytearray(b'hello world'), d=True, - e=datetime.utcfromtimestamp(872835240), f=Decimal('12.3E+7'), g=2.39, + e=datetime.fromtimestamp(872835240, tz=timezone.utc), f=Decimal('12.3E+7'), g=2.39, h=3.4028234663852886e+38, i='123.123.123.123', j=2147483647, k='text', l=UUID('FE2B4360-28C6-11E2-81C1-0800200C9A66'), - m=UUID('067e6162-3b6f-4ae2-a171-2470b63dff00'), n=int(str(2147483647) + '000')) + m=UUID('067e6162-3b6f-4ae2-a171-2470b63dff00'), n=int(str(2147483647) + '000'), + o=Duration(2, 3, 4)) self.assertEqual(1, AllDatatypesModel.objects.count()) - output = AllDatatypesModel.objects().first() + output = AllDatatypesModel.objects.first() for i, i_char in enumerate(range(ord('a'), ord('a') + 14)): self.assertEqual(input[i], output[chr(i_char)]) @@ -242,7 +265,7 @@ class v4DatatypesModel(Model): v4DatatypesModel.create(id=0, a=date(1970, 1, 1), b=32523, c=time(16, 47, 25, 7), d=123) self.assertEqual(1, v4DatatypesModel.objects.count()) - output = v4DatatypesModel.objects().first() + output = v4DatatypesModel.objects.first() for i, i_char in enumerate(range(ord('a'), ord('a') + 3)): self.assertEqual(input[i], output[chr(i_char)]) @@ -267,16 +290,16 @@ class FloatingPointModel(Model): sync_table(FloatingPointModel) FloatingPointModel.create(id=0, f=2.39) - output = FloatingPointModel.objects().first() + output = FloatingPointModel.objects.first() self.assertEqual(2.390000104904175, output.f) # float loses precision FloatingPointModel.create(id=0, f=3.4028234663852886e+38, d=2.39) - output = FloatingPointModel.objects().first() + output = FloatingPointModel.objects.first() self.assertEqual(3.4028234663852886e+38, output.f) self.assertEqual(2.39, output.d) # double retains precision FloatingPointModel.create(id=0, d=3.4028234663852886e+38) - output = FloatingPointModel.objects().first() + output = FloatingPointModel.objects.first() self.assertEqual(3.4028234663852886e+38, output.d) @@ -461,6 +484,49 @@ def test_previous_value_tracking_on_instantiation(self): self.assertTrue(self.instance._values['count'].previous_value is None) self.assertTrue(self.instance.count is None) + def test_previous_value_tracking_on_instantiation_with_default(self): + + class TestDefaultValueTracking(Model): + id = columns.Integer(partition_key=True) + int1 = columns.Integer(default=123) + int2 = columns.Integer(default=456) + int3 = columns.Integer(default=lambda: random.randint(0, 1000)) + int4 = columns.Integer(default=lambda: random.randint(0, 1000)) + int5 = columns.Integer() + int6 = columns.Integer() + + instance = TestDefaultValueTracking( + id=1, + int1=9999, + 
int3=7777, + int5=5555) + + self.assertEqual(instance.id, 1) + self.assertEqual(instance.int1, 9999) + self.assertEqual(instance.int2, 456) + self.assertEqual(instance.int3, 7777) + self.assertIsNotNone(instance.int4) + self.assertIsInstance(instance.int4, int) + self.assertGreaterEqual(instance.int4, 0) + self.assertLessEqual(instance.int4, 1000) + self.assertEqual(instance.int5, 5555) + self.assertTrue(instance.int6 is None) + + # All previous values are unset as the object hasn't been persisted + # yet. + self.assertTrue(instance._values['id'].previous_value is None) + self.assertTrue(instance._values['int1'].previous_value is None) + self.assertTrue(instance._values['int2'].previous_value is None) + self.assertTrue(instance._values['int3'].previous_value is None) + self.assertTrue(instance._values['int4'].previous_value is None) + self.assertTrue(instance._values['int5'].previous_value is None) + self.assertTrue(instance._values['int6'].previous_value is None) + + # All explicitely set columns, and those with default values are + # flagged has changed. + self.assertTrue(set(instance.get_changed_columns()) == set([ + 'id', 'int1', 'int3', 'int5'])) + def test_save_to_none(self): """ Test update of column value of None with save() function. @@ -638,6 +704,201 @@ def test_query_with_date(self): self.assertTrue(inst.date == day) +class BasicModelNoRouting(Model): + __table_name__ = 'basic_model_no_routing' + __compute_routing_key__ = False + k = columns.Integer(primary_key=True) + v = columns.Integer() + + +class BasicModel(Model): + __table_name__ = 'basic_model_routing' + k = columns.Integer(primary_key=True) + v = columns.Integer() + + +class BasicModelMulti(Model): + __table_name__ = 'basic_model_routing_multi' + k = columns.Integer(partition_key=True) + v = columns.Integer(partition_key=True) + + +class ComplexModelRouting(Model): + __table_name__ = 'complex_model_routing' + partition = columns.UUID(partition_key=True, default=uuid4) + cluster = columns.Integer(partition_key=True) + count = columns.Integer() + text = columns.Text(partition_key=True) + float = columns.Float(partition_key=True) + text_2 = columns.Text() + + +class TestModelRoutingKeys(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(TestModelRoutingKeys, cls).setUpClass() + sync_table(BasicModelNoRouting) + sync_table(BasicModel) + sync_table(BasicModelMulti) + sync_table(ComplexModelRouting) + + @classmethod + def tearDownClass(cls): + super(TestModelRoutingKeys, cls).tearDownClass() + drop_table(BasicModelNoRouting) + drop_table(BasicModel) + drop_table(BasicModelMulti) + drop_table(ComplexModelRouting) + + def test_routing_key_is_ignored(self): + """ + Compares the routing key generated by simple partition key using the model with the one generated by the equivalent + bound statement. It also verifies basic operations work with no routing key + @since 3.2 + @jira_ticket PYTHON-505 + @expected_result they shouldn't match + + @test_category object_mapper + """ + + prepared = self.session.prepare( + """ + INSERT INTO {0}.basic_model_no_routing (k, v) VALUES (?, ?) 
+ """.format(DEFAULT_KEYSPACE)) + bound = prepared.bind((1, 2)) + + mrk = BasicModelNoRouting._routing_key_from_values([1], self.session.cluster.protocol_version) + simple = SimpleStatement("") + simple.routing_key = mrk + self.assertNotEqual(bound.routing_key, simple.routing_key) + + # Verify that basic create, update and delete work with no routing key + t = BasicModelNoRouting.create(k=2, v=3) + t.update(v=4).save() + f = BasicModelNoRouting.objects.filter(k=2).first() + self.assertEqual(t, f) + + t.delete() + self.assertEqual(BasicModelNoRouting.objects.count(), 0) + + + def test_routing_key_generation_basic(self): + """ + Compares the routing key generated by simple partition key using the model with the one generated by the equivalent + bound statement + @since 3.2 + @jira_ticket PYTHON-535 + @expected_result they should match + + @test_category object_mapper + """ + + prepared = self.session.prepare( + """ + INSERT INTO {0}.basic_model_routing (k, v) VALUES (?, ?) + """.format(DEFAULT_KEYSPACE)) + bound = prepared.bind((1, 2)) + + mrk = BasicModel._routing_key_from_values([1], self.session.cluster.protocol_version) + simple = SimpleStatement("") + simple.routing_key = mrk + self.assertEqual(bound.routing_key, simple.routing_key) + + def test_routing_key_generation_multi(self): + """ + Compares the routing key generated by composite partition key using the model with the one generated by the equivalent + bound statement + @since 3.2 + @jira_ticket PYTHON-535 + @expected_result they should match + + @test_category object_mapper + """ + + prepared = self.session.prepare( + """ + INSERT INTO {0}.basic_model_routing_multi (k, v) VALUES (?, ?) + """.format(DEFAULT_KEYSPACE)) + bound = prepared.bind((1, 2)) + mrk = BasicModelMulti._routing_key_from_values([1, 2], self.session.cluster.protocol_version) + simple = SimpleStatement("") + simple.routing_key = mrk + self.assertEqual(bound.routing_key, simple.routing_key) + + def test_routing_key_generation_complex(self): + """ + Compares the routing key generated by complex composite partition key using the model with the one generated by the equivalent + bound statement + @since 3.2 + @jira_ticket PYTHON-535 + @expected_result they should match + + @test_category object_mapper + """ + prepared = self.session.prepare( + """ + INSERT INTO {0}.complex_model_routing (partition, cluster, count, text, float, text_2) VALUES (?, ?, ?, ?, ?, ?) + """.format(DEFAULT_KEYSPACE)) + partition = uuid4() + cluster = 1 + count = 2 + text = "text" + float = 1.2 + text_2 = "text_2" + bound = prepared.bind((partition, cluster, count, text, float, text_2)) + mrk = ComplexModelRouting._routing_key_from_values([partition, cluster, text, float], self.session.cluster.protocol_version) + simple = SimpleStatement("") + simple.routing_key = mrk + self.assertEqual(bound.routing_key, simple.routing_key) + + def test_partition_key_index(self): + """ + Test to ensure that statement partition key generation is in the correct order + @since 3.2 + @jira_ticket PYTHON-535 + @expected_result . 
+ + @test_category object_mapper + """ + self._check_partition_value_generation(BasicModel, SelectStatement(BasicModel.__table_name__)) + self._check_partition_value_generation(BasicModel, DeleteStatement(BasicModel.__table_name__)) + self._check_partition_value_generation(BasicModelMulti, SelectStatement(BasicModelMulti.__table_name__)) + self._check_partition_value_generation(BasicModelMulti, DeleteStatement(BasicModelMulti.__table_name__)) + self._check_partition_value_generation(ComplexModelRouting, SelectStatement(ComplexModelRouting.__table_name__)) + self._check_partition_value_generation(ComplexModelRouting, DeleteStatement(ComplexModelRouting.__table_name__)) + self._check_partition_value_generation(BasicModel, SelectStatement(BasicModel.__table_name__), reverse=True) + self._check_partition_value_generation(BasicModel, DeleteStatement(BasicModel.__table_name__), reverse=True) + self._check_partition_value_generation(BasicModelMulti, SelectStatement(BasicModelMulti.__table_name__), reverse=True) + self._check_partition_value_generation(BasicModelMulti, DeleteStatement(BasicModelMulti.__table_name__), reverse=True) + self._check_partition_value_generation(ComplexModelRouting, SelectStatement(ComplexModelRouting.__table_name__), reverse=True) + self._check_partition_value_generation(ComplexModelRouting, DeleteStatement(ComplexModelRouting.__table_name__), reverse=True) + + def _check_partition_value_generation(self, model, state, reverse=False): + """ + This generates a some statements based on the partition_key_index of the model. + It then validates that order of the partition key values in the statement matches the index + specified in the models partition_key_index + """ + # Setup some unique values for statement generation + uuid = uuid4() + values = {'k': 5, 'v': 3, 'partition': uuid, 'cluster': 6, 'count': 42, 'text': 'text', 'float': 3.1415, 'text_2': 'text_2'} + res = dict((v, k) for k, v in values.items()) + items = list(model._partition_key_index.items()) + if(reverse): + items.reverse() + # Add where clauses for each partition key + for partition_key, position in items: + wc = WhereClause(partition_key, EqualsOperator(), values.get(partition_key)) + state._add_where_clause(wc) + + # Iterate over the partition key values check to see that their index matches + # Those specified in the models partition field + for indx, value in enumerate(state.partition_key_values(model._partition_key_index)): + name = res.get(value) + self.assertEqual(indx, model._partition_key_index.get(name)) + + def test_none_filter_fails(): class NoneFilterModel(Model): diff --git a/tests/integration/cqlengine/model/test_polymorphism.py b/tests/integration/cqlengine/model/test_polymorphism.py index 18feb653c5..fc5e9c57ff 100644 --- a/tests/integration/cqlengine/model/test_polymorphism.py +++ b/tests/integration/cqlengine/model/test_polymorphism.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -13,7 +15,7 @@ # limitations under the License. import uuid -import mock +from unittest import mock from cassandra.cqlengine import columns from cassandra.cqlengine import models diff --git a/tests/integration/cqlengine/model/test_udts.py b/tests/integration/cqlengine/model/test_udts.py index dc7eb134c0..bab9c51c1f 100644 --- a/tests/integration/cqlengine/model/test_udts.py +++ b/tests/integration/cqlengine/model/test_udts.py @@ -1,34 +1,67 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
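The UDT test module below moves the `User` user type and the `UserModel` that embeds it up to module scope and runs them against `DEFAULT_KEYSPACE`. For reference, here is a minimal, self-contained sketch of that same define/sync/insert cycle; the host list, the keyspace name `udt_demo`, and the `Address`/`Person` names are illustrative only, and a reachable Cassandra node plus permission to manage schema are assumed:

    import os
    from uuid import uuid4

    from cassandra.cqlengine import columns, connection, management
    from cassandra.cqlengine.models import Model
    from cassandra.cqlengine.usertype import UserType

    # Management functions warn unless schema management is explicitly allowed.
    os.environ['CQLENG_ALLOW_SCHEMA_MANAGEMENT'] = '1'


    class Address(UserType):
        street = columns.Text()
        zip_code = columns.Integer()


    class Person(Model):
        __keyspace__ = 'udt_demo'
        id = columns.UUID(primary_key=True, default=uuid4)
        home = columns.UserDefinedType(Address)


    connection.setup(['127.0.0.1'], 'udt_demo')
    management.create_keyspace_simple('udt_demo', replication_factor=1)
    management.sync_type('udt_demo', Address)
    management.sync_table(Person)

    key = uuid4()
    Person.create(id=key, home=Address(street='Main St', zip_code=12345))
    print(Person.objects(id=key).first().home.street)   # -> Main St

Calling sync_type before sync_table mirrors the order used in the tests that follow.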
-try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest -from datetime import datetime, date, time +from datetime import datetime, date, time, timezone from decimal import Decimal -from mock import Mock +from unittest.mock import Mock from uuid import UUID, uuid4 from cassandra.cqlengine.models import Model from cassandra.cqlengine.usertype import UserType, UserTypeDefinitionException from cassandra.cqlengine import columns, connection -from cassandra.cqlengine.management import sync_table, sync_type, create_keyspace_simple, drop_keyspace, drop_table +from cassandra.cqlengine.management import sync_table, drop_table, sync_type, create_keyspace_simple, drop_keyspace +from cassandra.cqlengine import ValidationError from cassandra.util import Date, Time from tests.integration import PROTOCOL_VERSION from tests.integration.cqlengine.base import BaseCassEngTestCase +from tests.integration.cqlengine import DEFAULT_KEYSPACE + + +class User(UserType): + age = columns.Integer() + name = columns.Text() + + +class UserModel(Model): + id = columns.Integer(primary_key=True) + info = columns.UserDefinedType(User) + + +class AllDatatypes(UserType): + a = columns.Ascii() + b = columns.BigInt() + c = columns.Blob() + d = columns.Boolean() + e = columns.DateTime() + f = columns.Decimal() + g = columns.Double() + h = columns.Float() + i = columns.Inet() + j = columns.Integer() + k = columns.Text() + l = columns.TimeUUID() + m = columns.UUID() + n = columns.VarInt() + + +class AllDatatypesModel(Model): + id = columns.Integer(primary_key=True) + data = columns.UserDefinedType(AllDatatypes) class UserDefinedTypeTests(BaseCassEngTestCase): @@ -42,7 +75,7 @@ class User(UserType): age = columns.Integer() name = columns.Text() - sync_type("cqlengine_test", User) + sync_type(DEFAULT_KEYSPACE, User) user = User(age=42, name="John") self.assertEqual(42, user.age) self.assertEqual("John", user.name) @@ -53,8 +86,10 @@ class User(UserType): name = columns.Text() gender = columns.Text() - sync_type("cqlengine_test", User) - user = User(age=42, name="John", gender="male") + sync_type(DEFAULT_KEYSPACE, User) + user = User(age=42) + user["name"] = "John" + user["gender"] = "male" self.assertEqual(42, user.age) self.assertEqual("John", user.name) self.assertEqual("male", user.gender) @@ -64,117 +99,94 @@ class User(UserType): age = columns.Integer() name = columns.Text() - sync_type("cqlengine_test", User) + sync_type(DEFAULT_KEYSPACE, User) user = User(age=42, name="John", gender="male") with self.assertRaises(AttributeError): user.gender def test_can_insert_udts(self): - class User(UserType): - age = columns.Integer() - name = columns.Text() - - class UserModel(Model): - id = columns.Integer(primary_key=True) - info = columns.UserDefinedType(User) sync_table(UserModel) + self.addCleanup(drop_table, UserModel) user = User(age=42, name="John") UserModel.create(id=0, info=user) self.assertEqual(1, UserModel.objects.count()) - john = UserModel.objects().first() + john = UserModel.objects.first() self.assertEqual(0, john.id) self.assertTrue(type(john.info) is User) self.assertEqual(42, john.info.age) self.assertEqual("John", john.info.name) def test_can_update_udts(self): - class User(UserType): - age = columns.Integer() - name = columns.Text() - - class UserModel(Model): - id = columns.Integer(primary_key=True) - info = columns.UserDefinedType(User) - sync_table(UserModel) + self.addCleanup(drop_table, UserModel) user = User(age=42, name="John") created_user = UserModel.create(id=0, 
info=user) - john_info = UserModel.objects().first().info + john_info = UserModel.objects.first().info self.assertEqual(42, john_info.age) self.assertEqual("John", john_info.name) created_user.info = User(age=22, name="Mary") created_user.update() - mary_info = UserModel.objects().first().info - self.assertEqual(22, mary_info.age) - self.assertEqual("Mary", mary_info.name) + mary_info = UserModel.objects.first().info + self.assertEqual(22, mary_info["age"]) + self.assertEqual("Mary", mary_info["name"]) def test_can_update_udts_with_nones(self): - class User(UserType): - age = columns.Integer() - name = columns.Text() - - class UserModel(Model): - id = columns.Integer(primary_key=True) - info = columns.UserDefinedType(User) - sync_table(UserModel) + self.addCleanup(drop_table, UserModel) user = User(age=42, name="John") created_user = UserModel.create(id=0, info=user) - john_info = UserModel.objects().first().info + john_info = UserModel.objects.first().info self.assertEqual(42, john_info.age) self.assertEqual("John", john_info.name) created_user.info = None created_user.update() - john_info = UserModel.objects().first().info + john_info = UserModel.objects.first().info self.assertIsNone(john_info) def test_can_create_same_udt_different_keyspaces(self): - class User(UserType): - age = columns.Integer() - name = columns.Text() - - sync_type("cqlengine_test", User) + sync_type(DEFAULT_KEYSPACE, User) create_keyspace_simple("simplex", 1) sync_type("simplex", User) drop_keyspace("simplex") def test_can_insert_partial_udts(self): - class User(UserType): + class UserGender(UserType): age = columns.Integer() name = columns.Text() gender = columns.Text() - class UserModel(Model): + class UserModelGender(Model): id = columns.Integer(primary_key=True) - info = columns.UserDefinedType(User) + info = columns.UserDefinedType(UserGender) - sync_table(UserModel) + sync_table(UserModelGender) + self.addCleanup(drop_table, UserModelGender) - user = User(age=42, name="John") - UserModel.create(id=0, info=user) + user = UserGender(age=42, name="John") + UserModelGender.create(id=0, info=user) - john_info = UserModel.objects().first().info + john_info = UserModelGender.objects.first().info self.assertEqual(42, john_info.age) self.assertEqual("John", john_info.name) self.assertIsNone(john_info.gender) - user = User(age=42) - UserModel.create(id=0, info=user) + user = UserGender(age=42) + UserModelGender.create(id=0, info=user) - john_info = UserModel.objects().first().info + john_info = UserModelGender.objects.first().info self.assertEqual(42, john_info.age) self.assertIsNone(john_info.name) self.assertIsNone(john_info.gender) @@ -201,6 +213,7 @@ class DepthModel(Model): v_3 = columns.UserDefinedType(Depth_3) sync_table(DepthModel) + self.addCleanup(drop_table, DepthModel) udts = [Depth_0(age=42, name="John")] udts.append(Depth_1(value=udts[0])) @@ -208,7 +221,7 @@ class DepthModel(Model): udts.append(Depth_3(value=udts[2])) DepthModel.create(id=0, v_0=udts[0], v_1=udts[1], v_2=udts[2], v_3=udts[3]) - output = DepthModel.objects().first() + output = DepthModel.objects.first() self.assertEqual(udts[0], output.v_0) self.assertEqual(udts[1], output.v_1) @@ -230,28 +243,8 @@ def test_can_insert_udts_with_nones(self): @test_category data_types:udt """ - - class AllDatatypes(UserType): - a = columns.Ascii() - b = columns.BigInt() - c = columns.Blob() - d = columns.Boolean() - e = columns.DateTime() - f = columns.Decimal() - g = columns.Double() - h = columns.Float() - i = columns.Inet() - j = columns.Integer() - k = 
columns.Text() - l = columns.TimeUUID() - m = columns.UUID() - n = columns.VarInt() - - class AllDatatypesModel(Model): - id = columns.Integer(primary_key=True) - data = columns.UserDefinedType(AllDatatypes) - sync_table(AllDatatypesModel) + self.addCleanup(drop_table, AllDatatypesModel) input = AllDatatypes(a=None, b=None, c=None, d=None, e=None, f=None, g=None, h=None, i=None, j=None, k=None, l=None, m=None, n=None) @@ -259,7 +252,7 @@ class AllDatatypesModel(Model): self.assertEqual(1, AllDatatypesModel.objects.count()) - output = AllDatatypesModel.objects().first().data + output = AllDatatypesModel.objects.first().data self.assertEqual(input, output) def test_can_insert_udts_with_all_datatypes(self): @@ -277,38 +270,18 @@ def test_can_insert_udts_with_all_datatypes(self): @test_category data_types:udt """ - - class AllDatatypes(UserType): - a = columns.Ascii() - b = columns.BigInt() - c = columns.Blob() - d = columns.Boolean() - e = columns.DateTime() - f = columns.Decimal() - g = columns.Double() - h = columns.Float() - i = columns.Inet() - j = columns.Integer() - k = columns.Text() - l = columns.TimeUUID() - m = columns.UUID() - n = columns.VarInt() - - class AllDatatypesModel(Model): - id = columns.Integer(primary_key=True) - data = columns.UserDefinedType(AllDatatypes) - sync_table(AllDatatypesModel) + self.addCleanup(drop_table, AllDatatypesModel) input = AllDatatypes(a='ascii', b=2 ** 63 - 1, c=bytearray(b'hello world'), d=True, - e=datetime.utcfromtimestamp(872835240), f=Decimal('12.3E+7'), g=2.39, + e=datetime.fromtimestamp(872835240, tz=timezone.utc).replace(tzinfo=None), f=Decimal('12.3E+7'), g=2.39, h=3.4028234663852886e+38, i='123.123.123.123', j=2147483647, k='text', l=UUID('FE2B4360-28C6-11E2-81C1-0800200C9A66'), m=UUID('067e6162-3b6f-4ae2-a171-2470b63dff00'), n=int(str(2147483647) + '000')) AllDatatypesModel.create(id=0, data=input) self.assertEqual(1, AllDatatypesModel.objects.count()) - output = AllDatatypesModel.objects().first().data + output = AllDatatypesModel.objects.first().data for i in range(ord('a'), ord('a') + 14): self.assertEqual(input[chr(i)], output[chr(i)]) @@ -344,12 +317,13 @@ class Allv4DatatypesModel(Model): data = columns.UserDefinedType(Allv4Datatypes) sync_table(Allv4DatatypesModel) + self.addCleanup(drop_table, Allv4DatatypesModel) input = Allv4Datatypes(a=Date(date(1970, 1, 1)), b=32523, c=Time(time(16, 47, 25, 7)), d=123) Allv4DatatypesModel.create(id=0, data=input) self.assertEqual(1, Allv4DatatypesModel.objects.count()) - output = Allv4DatatypesModel.objects().first().data + output = Allv4DatatypesModel.objects.first().data for i in range(ord('a'), ord('a') + 3): self.assertEqual(input[chr(i)], output[chr(i)]) @@ -387,11 +361,13 @@ class Container(Model): # Create table, insert data sync_table(Container) + self.addCleanup(drop_table, Container) + Container.create(id=UUID('FE2B4360-28C6-11E2-81C1-0800200C9A66'), names=names) # Validate input and output matches self.assertEqual(1, Container.objects.count()) - names_output = Container.objects().first().names + names_output = Container.objects.first().names self.assertEqual(names_output, names) def test_udts_with_unicode(self): @@ -410,15 +386,13 @@ def test_udts_with_unicode(self): ascii_name = 'normal name' unicode_name = u'Fran\u00E7ois' - class User(UserType): - age = columns.Integer() - name = columns.Text() - class UserModelText(Model): id = columns.Text(primary_key=True) info = columns.UserDefinedType(User) sync_table(UserModelText) + self.addCleanup(drop_table, UserModelText) + # Two udt 
instances one with a unicode one with ascii user_template_ascii = User(age=25, name=ascii_name) user_template_unicode = User(age=25, name=unicode_name) @@ -428,9 +402,6 @@ class UserModelText(Model): UserModelText.create(id=unicode_name, info=user_template_unicode) def test_register_default_keyspace(self): - class User(UserType): - age = columns.Integer() - name = columns.Text() from cassandra.cqlengine import models from cassandra.cqlengine import connection @@ -467,6 +438,7 @@ class TheModel(Model): info = columns.UserDefinedType(db_field_different) sync_table(TheModel) + self.addCleanup(drop_table, TheModel) cluster = connection.get_cluster() type_meta = cluster.metadata.keyspaces[TheModel._get_keyspace()].user_types[db_field_different.type_name()] @@ -485,7 +457,7 @@ class TheModel(Model): self.assertEqual(1, TheModel.objects.count()) - john = TheModel.objects().first() + john = TheModel.objects.first() self.assertEqual(john.id, id) info = john.info self.assertIsInstance(info, db_field_different) @@ -520,10 +492,99 @@ class something_silly_2(UserType): def test_set_udt_fields(self): # PYTHON-502 - class User(UserType): - age = columns.Integer() - name = columns.Text() u = User() u.age = 20 self.assertEqual(20, u.age) + + def test_default_values(self): + """ + Test that default types are set on object creation for UDTs + + @since 3.7.0 + @jira_ticket PYTHON-606 + @expected_result Default values should be set. + + @test_category data_types:udt + """ + + class NestedUdt(UserType): + + test_id = columns.UUID(default=uuid4) + something = columns.Text() + default_text = columns.Text(default="default text") + + class OuterModel(Model): + + name = columns.Text(primary_key=True) + first_name = columns.Text() + nested = columns.List(columns.UserDefinedType(NestedUdt)) + simple = columns.UserDefinedType(NestedUdt) + + sync_table(OuterModel) + self.addCleanup(drop_table, OuterModel) + + t = OuterModel.create(name='test1') + t.nested = [NestedUdt(something='test')] + t.simple = NestedUdt(something="") + t.save() + self.assertIsNotNone(t.nested[0].test_id) + self.assertEqual(t.nested[0].default_text, "default text") + self.assertIsNotNone(t.simple.test_id) + self.assertEqual(t.simple.default_text, "default text") + + def test_udt_validate(self): + """ + Test to verify restrictions are honored and that validate is called + for each member of the UDT when an updated is attempted + + @since 3.10 + @jira_ticket PYTHON-505 + @expected_result a validation error is arisen due to the name being + too long + + @test_category data_types:object_mapper + """ + class UserValidate(UserType): + age = columns.Integer() + name = columns.Text(max_length=2) + + class UserModelValidate(Model): + id = columns.Integer(primary_key=True) + info = columns.UserDefinedType(UserValidate) + + sync_table(UserModelValidate) + self.addCleanup(drop_table, UserModelValidate) + + user = UserValidate(age=1, name="Robert") + item = UserModelValidate(id=1, info=user) + with self.assertRaises(ValidationError): + item.save() + + def test_udt_validate_with_default(self): + """ + Test to verify restrictions are honored and that validate is called + on the default value + + @since 3.10 + @jira_ticket PYTHON-505 + @expected_result a validation error is arisen due to the name being + too long + + @test_category data_types:object_mapper + """ + class UserValidateDefault(UserType): + age = columns.Integer() + name = columns.Text(max_length=2, default="Robert") + + class UserModelValidateDefault(Model): + id = columns.Integer(primary_key=True) + 
info = columns.UserDefinedType(UserValidateDefault) + + sync_table(UserModelValidateDefault) + self.addCleanup(drop_table, UserModelValidateDefault) + + user = UserValidateDefault(age=1) + item = UserModelValidateDefault(id=1, info=user) + with self.assertRaises(ValidationError): + item.save() diff --git a/tests/integration/cqlengine/model/test_updates.py b/tests/integration/cqlengine/model/test_updates.py index 242bffe12f..096417baac 100644 --- a/tests/integration/cqlengine/model/test_updates.py +++ b/tests/integration/cqlengine/model/test_updates.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,16 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import patch from uuid import uuid4 -from mock import patch from cassandra.cqlengine import ValidationError +from tests.integration import greaterthancass21 from tests.integration.cqlengine.base import BaseCassEngTestCase from cassandra.cqlengine.models import Model from cassandra.cqlengine import columns from cassandra.cqlengine.management import sync_table, drop_table - +from cassandra.cqlengine.usertype import UserType class TestUpdateModel(Model): @@ -61,6 +64,17 @@ def test_update_model(self): self.assertEqual(m2.count, m1.count) self.assertEqual(m2.text, m0.text) + #This shouldn't raise a Validation error as the PR is not changing + m0.update(partition=m0.partition, cluster=m0.cluster) + + #Assert a ValidationError is risen if the PR changes + with self.assertRaises(ValidationError): + m0.update(partition=m0.partition, cluster=20) + + # Assert a ValidationError is risen if the columns doesn't exist + with self.assertRaises(ValidationError): + m0.update(invalid_column=20) + def test_update_values(self): """ tests calling update on models with values passed in """ m0 = TestUpdateModel.create(count=5, text='monkey') @@ -79,8 +93,8 @@ def test_update_values(self): self.assertEqual(m2.count, m1.count) self.assertEqual(m2.text, m0.text) - def test_noop_model_update(self): - """ tests that calling update on a model with no changes will do nothing. """ + def test_noop_model_direct_update(self): + """ Tests that calling update on a model with no changes will do nothing. 
""" m0 = TestUpdateModel.create(count=5, text='monkey') with patch.object(self.session, 'execute') as execute: @@ -91,6 +105,38 @@ def test_noop_model_update(self): m0.update(count=5) assert execute.call_count == 0 + with patch.object(self.session, 'execute') as execute: + m0.update(partition=m0.partition) + + with patch.object(self.session, 'execute') as execute: + m0.update(cluster=m0.cluster) + + def test_noop_model_assignation_update(self): + """ Tests that assigning the same value on a model will do nothing. """ + # Create object and fetch it back to eliminate any hidden variable + # cache effect. + m0 = TestUpdateModel.create(count=5, text='monkey') + m1 = TestUpdateModel.get(partition=m0.partition, cluster=m0.cluster) + + with patch.object(self.session, 'execute') as execute: + m1.save() + assert execute.call_count == 0 + + with patch.object(self.session, 'execute') as execute: + m1.count = 5 + m1.save() + assert execute.call_count == 0 + + with patch.object(self.session, 'execute') as execute: + m1.partition = m0.partition + m1.save() + assert execute.call_count == 0 + + with patch.object(self.session, 'execute') as execute: + m1.cluster = m0.cluster + m1.save() + assert execute.call_count == 0 + def test_invalid_update_kwarg(self): """ tests that passing in a kwarg to the update method that isn't a column will fail """ m0 = TestUpdateModel.create(count=5, text='monkey') @@ -102,3 +148,228 @@ def test_primary_key_update_failure(self): m0 = TestUpdateModel.create(count=5, text='monkey') with self.assertRaises(ValidationError): m0.update(partition=uuid4()) + + +class UDT(UserType): + age = columns.Integer() + mf = columns.Map(columns.Integer, columns.Integer) + dummy_udt = columns.Integer(default=42) + time_col = columns.Time() + + +class ModelWithDefault(Model): + id = columns.Integer(primary_key=True) + mf = columns.Map(columns.Integer, columns.Integer) + dummy = columns.Integer(default=42) + udt = columns.UserDefinedType(UDT) + udt_default = columns.UserDefinedType(UDT, default=UDT(age=1, mf={2:2})) + + +class UDTWithDefault(UserType): + age = columns.Integer() + mf = columns.Map(columns.Integer, columns.Integer, default={2:2}) + dummy_udt = columns.Integer(default=42) + + +class ModelWithDefaultCollection(Model): + id = columns.Integer(primary_key=True) + mf = columns.Map(columns.Integer, columns.Integer, default={2:2}) + dummy = columns.Integer(default=42) + udt = columns.UserDefinedType(UDT) + udt_default = columns.UserDefinedType(UDT, default=UDT(age=1, mf={2: 2})) + +@greaterthancass21 +class ModelWithDefaultTests(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + cls.udt_default = UDT(age=1, mf={2:2}, dummy_udt=42) + + def setUp(self): + sync_table(ModelWithDefault) + sync_table(ModelWithDefaultCollection) + + def tearDown(self): + drop_table(ModelWithDefault) + drop_table(ModelWithDefaultCollection) + + def test_value_override_with_default(self): + """ + Updating a row with a new Model instance shouldn't set columns to defaults + + @since 3.9 + @jira_ticket PYTHON-657 + @expected_result column value should not change + + @test_category object_mapper + """ + first_udt = UDT(age=1, mf={2:2}, dummy_udt=0) + initial = ModelWithDefault(id=1, mf={0: 0}, dummy=0, udt=first_udt, udt_default=first_udt) + initial.save() + + self.assertEqual(ModelWithDefault.get()._as_dict(), + {'id': 1, 'dummy': 0, 'mf': {0: 0}, "udt": first_udt, "udt_default": first_udt}) + + second_udt = UDT(age=1, mf={3: 3}, dummy_udt=12) + second = ModelWithDefault(id=1) + second.update(mf={0: 1}, 
udt=second_udt) + + self.assertEqual(ModelWithDefault.get()._as_dict(), + {'id': 1, 'dummy': 0, 'mf': {0: 1}, "udt": second_udt, "udt_default": first_udt}) + + def test_value_is_written_if_is_default(self): + """ + Check if the we try to update with the default value, the update + happens correctly + @since 3.9 + @jira_ticket PYTHON-657 + @expected_result column value should be updated + :return: + """ + initial = ModelWithDefault(id=1) + initial.mf = {0: 0} + initial.dummy = 42 + initial.udt_default = self.udt_default + initial.update() + + self.assertEqual(ModelWithDefault.get()._as_dict(), + {'id': 1, 'dummy': 42, 'mf': {0: 0}, "udt": None, "udt_default": self.udt_default}) + + def test_null_update_is_respected(self): + """ + Check if the we try to update with None under particular + circumstances, it works correctly + @since 3.9 + @jira_ticket PYTHON-657 + @expected_result column value should be updated to None + + @test_category object_mapper + :return: + """ + ModelWithDefault.create(id=1, mf={0: 0}).save() + + q = ModelWithDefault.objects.all().allow_filtering() + obj = q.filter(id=1).get() + + updated_udt = UDT(age=1, mf={2:2}, dummy_udt=None) + obj.update(dummy=None, udt_default=updated_udt) + + self.assertEqual(ModelWithDefault.get()._as_dict(), + {'id': 1, 'dummy': None, 'mf': {0: 0}, "udt": None, "udt_default": updated_udt}) + + def test_only_set_values_is_updated(self): + """ + Test the updates work as expected when an object is deleted + @since 3.9 + @jira_ticket PYTHON-657 + @expected_result the non updated column is None and the + updated column has the set value + + @test_category object_mapper + """ + + ModelWithDefault.create(id=1, mf={1: 1}, dummy=1).save() + + item = ModelWithDefault.filter(id=1).first() + ModelWithDefault.objects(id=1).delete() + item.mf = {1: 2} + udt, udt_default = UDT(age=1, mf={2:3}), UDT(age=1, mf={2:3}) + item.udt, item.udt_default = udt, udt_default + item.save() + + self.assertEqual(ModelWithDefault.get()._as_dict(), + {'id': 1, 'dummy': None, 'mf': {1: 2}, "udt": udt, "udt_default": udt_default}) + + def test_collections(self): + """ + Test the updates work as expected on Map objects + @since 3.9 + @jira_ticket PYTHON-657 + @expected_result the row is updated when the Map object is + reduced + + @test_category object_mapper + """ + udt, udt_default = UDT(age=1, mf={1: 1, 2: 1}), UDT(age=1, mf={1: 1, 2: 1}) + + ModelWithDefault.create(id=1, mf={1: 1, 2: 1}, dummy=1, udt=udt, udt_default=udt_default).save() + item = ModelWithDefault.filter(id=1).first() + + udt, udt_default = UDT(age=1, mf={2: 1}), UDT(age=1, mf={2: 1}) + item.update(mf={2:1}, udt=udt, udt_default=udt_default) + self.assertEqual(ModelWithDefault.get()._as_dict(), + {'id': 1, 'dummy': 1, 'mf': {2: 1}, "udt": udt, "udt_default": udt_default}) + + def test_collection_with_default(self): + """ + Test the updates work as expected when an object is deleted + @since 3.9 + @jira_ticket PYTHON-657 + @expected_result the non updated column is None and the + updated column has the set value + + @test_category object_mapper + """ + sync_table(ModelWithDefaultCollection) + + udt, udt_default = UDT(age=1, mf={6: 6}), UDT(age=1, mf={6: 6}) + + item = ModelWithDefaultCollection.create(id=1, mf={1: 1}, dummy=1, udt=udt, udt_default=udt_default).save() + self.assertEqual(ModelWithDefaultCollection.objects.get(id=1)._as_dict(), + {'id': 1, 'dummy': 1, 'mf': {1: 1}, "udt": udt, "udt_default": udt_default}) + + udt, udt_default = UDT(age=1, mf={5: 5}), UDT(age=1, mf={5: 5}) + item.update(mf={2: 2}, 
udt=udt, udt_default=udt_default) + self.assertEqual(ModelWithDefaultCollection.objects.get(id=1)._as_dict(), + {'id': 1, 'dummy': 1, 'mf': {2: 2}, "udt": udt, "udt_default": udt_default}) + + udt, udt_default = UDT(age=1, mf=None), UDT(age=1, mf=None) + expected_udt, expected_udt_default = UDT(age=1, mf={}), UDT(age=1, mf={}) + item.update(mf=None, udt=udt, udt_default=udt_default) + self.assertEqual(ModelWithDefaultCollection.objects.get(id=1)._as_dict(), + {'id': 1, 'dummy': 1, 'mf': {}, "udt": expected_udt, "udt_default": expected_udt_default}) + + udt_default = UDT(age=1, mf={2:2}, dummy_udt=42) + item = ModelWithDefaultCollection.create(id=2, dummy=2) + self.assertEqual(ModelWithDefaultCollection.objects.get(id=2)._as_dict(), + {'id': 2, 'dummy': 2, 'mf': {2: 2}, "udt": None, "udt_default": udt_default}) + + udt, udt_default = UDT(age=1, mf={1: 1, 6: 6}), UDT(age=1, mf={1: 1, 6: 6}) + item.update(mf={1: 1, 4: 4}, udt=udt, udt_default=udt_default) + self.assertEqual(ModelWithDefaultCollection.objects.get(id=2)._as_dict(), + {'id': 2, 'dummy': 2, 'mf': {1: 1, 4: 4}, "udt": udt, "udt_default": udt_default}) + + item.update(udt_default=None) + self.assertEqual(ModelWithDefaultCollection.objects.get(id=2)._as_dict(), + {'id': 2, 'dummy': 2, 'mf': {1: 1, 4: 4}, "udt": udt, "udt_default": None}) + + udt_default = UDT(age=1, mf={2:2}) + item.update(udt_default=udt_default) + self.assertEqual(ModelWithDefaultCollection.objects.get(id=2)._as_dict(), + {'id': 2, 'dummy': 2, 'mf': {1: 1, 4: 4}, "udt": udt, "udt_default": udt_default}) + + + def test_udt_to_python(self): + """ + Test the to_python and to_database are correctly called on UDTs + @since 3.10 + @jira_ticket PYTHON-743 + @expected_result the int value is correctly converted to utils.Time + and written to C* + + @test_category object_mapper + """ + item = ModelWithDefault(id=1) + item.save() + + # We update time_col this way because we want to hit + # the to_python method from UserDefinedType, otherwise to_python + # would be called in UDT.__init__ + user_to_update = UDT() + user_to_update.time_col = 10 + + item.update(udt=user_to_update) + + udt, udt_default = UDT(time_col=10), UDT(age=1, mf={2:2}) + self.assertEqual(ModelWithDefault.objects.get(id=1)._as_dict(), + {'id': 1, 'dummy': 42, 'mf': {}, "udt": udt, "udt_default": udt_default}) diff --git a/tests/integration/cqlengine/model/test_value_lists.py b/tests/integration/cqlengine/model/test_value_lists.py index 8bd9b218f5..a6fc0b25f3 100644 --- a/tests/integration/cqlengine/model/test_value_lists.py +++ b/tests/integration/cqlengine/model/test_value_lists.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/tests/integration/cqlengine/operators/__init__.py b/tests/integration/cqlengine/operators/__init__.py index 1c7af46e71..9d1d6564dc 100644 --- a/tests/integration/cqlengine/operators/__init__.py +++ b/tests/integration/cqlengine/operators/__init__.py @@ -1,13 +1,22 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from cassandra.cqlengine.operators import BaseWhereOperator + + +def check_lookup(test_case, symbol, expected): + op = BaseWhereOperator.get_operator(symbol) + test_case.assertEqual(op, expected) diff --git a/tests/integration/cqlengine/operators/test_where_operators.py b/tests/integration/cqlengine/operators/test_where_operators.py index 72ff2f6263..808e14df04 100644 --- a/tests/integration/cqlengine/operators/test_where_operators.py +++ b/tests/integration/cqlengine/operators/test_where_operators.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,36 +14,98 @@ # See the License for the specific language governing permissions and # limitations under the License. 
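The operator tests that follow resolve symbol strings to operator classes through `BaseWhereOperator.get_operator` and render instances with `str()`. A standalone sketch of that lookup, runnable without any cluster connection (the symbols and expected classes are the ones asserted below):

    from cassandra.cqlengine.operators import (
        BaseWhereOperator,
        EqualsOperator,
        GreaterThanOperator,
        InOperator,
    )

    # get_operator maps a symbol string to the operator class itself.
    assert BaseWhereOperator.get_operator('EQ') is EqualsOperator
    assert BaseWhereOperator.get_operator('IN') is InOperator
    assert BaseWhereOperator.get_operator('GT') is GreaterThanOperator

    # Instances render as the CQL fragment used when statements are built.
    for op_class in (EqualsOperator, InOperator, GreaterThanOperator):
        print(op_class.__name__, 'renders as', str(op_class()))   # =, IN, >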
-from unittest import TestCase +import unittest + from cassandra.cqlengine.operators import * -import six +from uuid import uuid4 + +from cassandra.cqlengine.management import sync_table, drop_table +from cassandra.cqlengine.operators import IsNotNullOperator +from cassandra.cqlengine.statements import IsNotNull +from cassandra import InvalidRequest + +from tests.integration.cqlengine.base import TestQueryUpdateModel, BaseCassEngTestCase +from tests.integration.cqlengine.operators import check_lookup +from tests.integration import greaterthanorequalcass30 -class TestWhereOperators(TestCase): + +class TestWhereOperators(unittest.TestCase): def test_symbol_lookup(self): """ tests where symbols are looked up properly """ - def check_lookup(symbol, expected): - op = BaseWhereOperator.get_operator(symbol) - self.assertEqual(op, expected) - - check_lookup('EQ', EqualsOperator) - check_lookup('IN', InOperator) - check_lookup('GT', GreaterThanOperator) - check_lookup('GTE', GreaterThanOrEqualOperator) - check_lookup('LT', LessThanOperator) - check_lookup('LTE', LessThanOrEqualOperator) - check_lookup('CONTAINS', ContainsOperator) + check_lookup(self, 'EQ', EqualsOperator) + check_lookup(self, 'NE', NotEqualsOperator) + check_lookup(self, 'IN', InOperator) + check_lookup(self, 'GT', GreaterThanOperator) + check_lookup(self, 'GTE', GreaterThanOrEqualOperator) + check_lookup(self, 'LT', LessThanOperator) + check_lookup(self, 'LTE', LessThanOrEqualOperator) + check_lookup(self, 'CONTAINS', ContainsOperator) + check_lookup(self, 'LIKE', LikeOperator) def test_operator_rendering(self): """ tests symbols are rendered properly """ - self.assertEqual("=", six.text_type(EqualsOperator())) - self.assertEqual("IN", six.text_type(InOperator())) - self.assertEqual(">", six.text_type(GreaterThanOperator())) - self.assertEqual(">=", six.text_type(GreaterThanOrEqualOperator())) - self.assertEqual("<", six.text_type(LessThanOperator())) - self.assertEqual("<=", six.text_type(LessThanOrEqualOperator())) - self.assertEqual("CONTAINS", six.text_type(ContainsOperator())) + self.assertEqual("=", str(EqualsOperator())) + self.assertEqual("!=", str(NotEqualsOperator())) + self.assertEqual("IN", str(InOperator())) + self.assertEqual(">", str(GreaterThanOperator())) + self.assertEqual(">=", str(GreaterThanOrEqualOperator())) + self.assertEqual("<", str(LessThanOperator())) + self.assertEqual("<=", str(LessThanOrEqualOperator())) + self.assertEqual("CONTAINS", str(ContainsOperator())) + self.assertEqual("LIKE", str(LikeOperator())) + + +class TestIsNotNull(BaseCassEngTestCase): + def test_is_not_null_to_cql(self): + """ + Verify that IsNotNull is converted correctly to CQL + + @since 2.5 + @jira_ticket PYTHON-968 + @expected_result the strings match + + @test_category cqlengine + """ + + check_lookup(self, 'IS NOT NULL', IsNotNullOperator) + + # The * is not expanded because there are no referred fields + self.assertEqual( + str(TestQueryUpdateModel.filter(IsNotNull("text")).limit(2)), + 'SELECT * FROM cqlengine_test.test_query_update_model WHERE "text" IS NOT NULL LIMIT 2' + ) + + # We already know partition so cqlengine doesn't query for it + self.assertEqual( + str(TestQueryUpdateModel.filter(IsNotNull("text"), partition=uuid4())), + ('SELECT "cluster", "count", "text", "text_set", ' + '"text_list", "text_map" FROM cqlengine_test.test_query_update_model ' + 'WHERE "text" IS NOT NULL AND "partition" = %(0)s LIMIT 10000') + ) + + @greaterthanorequalcass30 + def test_is_not_null_execution(self): + """ + Verify that CQL 
statements have correct syntax when executed + If we wanted them to return something meaningful and not a InvalidRequest + we'd have to create an index in search for the column we are using + IsNotNull + + @since 2.5 + @jira_ticket PYTHON-968 + @expected_result InvalidRequest is arisen + + @test_category cqlengine + """ + sync_table(TestQueryUpdateModel) + self.addCleanup(drop_table, TestQueryUpdateModel) + # Raises InvalidRequest instead of dse.protocol.SyntaxException + with self.assertRaises(InvalidRequest): + list(TestQueryUpdateModel.filter(IsNotNull("text"))) + with self.assertRaises(InvalidRequest): + list(TestQueryUpdateModel.filter(IsNotNull("text"), partition=uuid4())) diff --git a/tests/integration/cqlengine/query/__init__.py b/tests/integration/cqlengine/query/__init__.py index 87fc3685e0..588a655d98 100644 --- a/tests/integration/cqlengine/query/__init__.py +++ b/tests/integration/cqlengine/query/__init__.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/tests/integration/cqlengine/query/test_batch_query.py b/tests/integration/cqlengine/query/test_batch_query.py index 126ad21e36..d3cddc0c7e 100644 --- a/tests/integration/cqlengine/query/test_batch_query.py +++ b/tests/integration/cqlengine/query/test_batch_query.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
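The batch tests that follow drive `BatchQuery` both as an explicit object and as a context manager. A minimal sketch of the two styles, assuming `connection.setup(...)` has already been called and the hypothetical `Event` model has been synced with `sync_table`:

    from cassandra.cqlengine import columns
    from cassandra.cqlengine.models import Model
    from cassandra.cqlengine.query import BatchQuery


    class Event(Model):
        partition = columns.Integer(primary_key=True)
        cluster = columns.Integer(primary_key=True)
        payload = columns.Text()


    # Explicit batch: statements accumulate and are only sent on execute().
    b = BatchQuery()
    Event.batch(b).create(partition=1, cluster=1, payload='a')
    Event.batch(b).create(partition=1, cluster=2, payload='b')
    b.execute()

    # Context-manager batch: executed automatically on a clean exit, which is
    # the form most of the tests below rely on.
    with BatchQuery() as b:
        Event.batch(b).create(partition=1, cluster=3, payload='c')

The `BatchTypeQueryTests` added below additionally pass `batch_type=` (for example `cassandra.query.BatchType.UNLOGGED` or the counter type) to the same constructor.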
-import mock +from unittest import mock from cassandra.cqlengine import columns from cassandra.cqlengine.connection import NOT_SET @@ -20,7 +22,10 @@ from cassandra.cqlengine.models import Model from cassandra.cqlengine.query import BatchQuery, DMLQuery from tests.integration.cqlengine.base import BaseCassEngTestCase +from tests.integration.cqlengine import execute_count from cassandra.cluster import Session +from cassandra.query import BatchType as cassandra_BatchType +from cassandra.cqlengine.query import BatchType as cqlengine_BatchType class TestMultiKeyModel(Model): @@ -36,6 +41,11 @@ class BatchQueryLogModel(Model): k = columns.Integer(primary_key=True) v = columns.Integer() + +class CounterBatchQueryModel(Model): + k = columns.Integer(primary_key=True) + v = columns.Counter() + class BatchQueryTests(BaseCassEngTestCase): @classmethod @@ -55,6 +65,7 @@ def setUp(self): for obj in TestMultiKeyModel.filter(partition=self.pkey): obj.delete() + @execute_count(3) def test_insert_success_case(self): b = BatchQuery() @@ -67,6 +78,7 @@ def test_insert_success_case(self): TestMultiKeyModel.get(partition=self.pkey, cluster=2) + @execute_count(4) def test_update_success_case(self): inst = TestMultiKeyModel.create(partition=self.pkey, cluster=2, count=3, text='4') @@ -84,6 +96,7 @@ def test_update_success_case(self): inst3 = TestMultiKeyModel.get(partition=self.pkey, cluster=2) self.assertEqual(inst3.count, 4) + @execute_count(4) def test_delete_success_case(self): inst = TestMultiKeyModel.create(partition=self.pkey, cluster=2, count=3, text='4') @@ -99,6 +112,7 @@ def test_delete_success_case(self): with self.assertRaises(TestMultiKeyModel.DoesNotExist): TestMultiKeyModel.get(partition=self.pkey, cluster=2) + @execute_count(11) def test_context_manager(self): with BatchQuery() as b: @@ -112,6 +126,7 @@ def test_context_manager(self): for i in range(5): TestMultiKeyModel.get(partition=self.pkey, cluster=i) + @execute_count(9) def test_bulk_delete_success_case(self): for i in range(1): @@ -127,6 +142,7 @@ def test_bulk_delete_success_case(self): for m in TestMultiKeyModel.all(): m.delete() + @execute_count(0) def test_none_success_case(self): """ Tests that passing None into the batch call clears any batch object """ b = BatchQuery() @@ -137,6 +153,7 @@ def test_none_success_case(self): q = q.batch(None) self.assertIsNone(q._batch) + @execute_count(0) def test_dml_none_success_case(self): """ Tests that passing None into the batch call clears any batch object """ b = BatchQuery() @@ -147,6 +164,7 @@ def test_dml_none_success_case(self): q.batch(None) self.assertIsNone(q._batch) + @execute_count(3) def test_batch_execute_on_exception_succeeds(self): # makes sure if execute_on_exception == True we still apply the batch drop_table(BatchQueryLogModel) @@ -166,6 +184,7 @@ def test_batch_execute_on_exception_succeeds(self): # should be 1 because the batch should execute self.assertEqual(1, len(obj)) + @execute_count(2) def test_batch_execute_on_exception_skips_if_not_specified(self): # makes sure if execute_on_exception == True we still apply the batch drop_table(BatchQueryLogModel) @@ -186,14 +205,89 @@ def test_batch_execute_on_exception_skips_if_not_specified(self): # should be 0 because the batch should not execute self.assertEqual(0, len(obj)) + @execute_count(1) def test_batch_execute_timeout(self): with mock.patch.object(Session, 'execute') as mock_execute: with BatchQuery(timeout=1) as b: BatchQueryLogModel.batch(b).create(k=2, v=2) self.assertEqual(mock_execute.call_args[-1]['timeout'], 1) + 
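The two timeout tests on either side of this point patch `Session.execute` and then read the captured keyword arguments off `call_args`. The same inspection pattern in isolation, using a stand-in class so it runs without a cluster (`FakeSession` is purely illustrative):

    from unittest import mock


    class FakeSession:
        def execute(self, statement, timeout=None):
            raise NotImplementedError   # never reached while patched


    with mock.patch.object(FakeSession, 'execute') as mock_execute:
        FakeSession().execute('SELECT release_version FROM system.local', timeout=1)

    # call_args behaves like an (args, kwargs) tuple, so [-1] is the kwargs
    # dict; that is exactly the lookup the assertions above and below perform.
    assert mock_execute.call_args[-1]['timeout'] == 1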
@execute_count(1) def test_batch_execute_no_timeout(self): with mock.patch.object(Session, 'execute') as mock_execute: with BatchQuery() as b: BatchQueryLogModel.batch(b).create(k=2, v=2) self.assertEqual(mock_execute.call_args[-1]['timeout'], NOT_SET) + + +class BatchTypeQueryTests(BaseCassEngTestCase): + def setUp(self): + sync_table(TestMultiKeyModel) + sync_table(CounterBatchQueryModel) + + def tearDown(self): + drop_table(TestMultiKeyModel) + drop_table(CounterBatchQueryModel) + + @execute_count(6) + def test_cassandra_batch_type(self): + """ + Tests the different types of `class: cassandra.query.BatchType` + + @since 3.13 + @jira_ticket PYTHON-88 + @expected_result batch query succeeds and the results + are correctly readen + + @test_category query + """ + with BatchQuery(batch_type=cassandra_BatchType.UNLOGGED) as b: + TestMultiKeyModel.batch(b).create(partition=1, cluster=1) + TestMultiKeyModel.batch(b).create(partition=1, cluster=2) + + obj = TestMultiKeyModel.objects(partition=1) + self.assertEqual(2, len(obj)) + + with BatchQuery(batch_type=cassandra_BatchType.COUNTER) as b: + CounterBatchQueryModel.batch(b).create(k=1, v=1) + CounterBatchQueryModel.batch(b).create(k=1, v=2) + CounterBatchQueryModel.batch(b).create(k=1, v=10) + + obj = CounterBatchQueryModel.objects(k=1) + self.assertEqual(1, len(obj)) + self.assertEqual(obj[0].v, 13) + + with BatchQuery(batch_type=cassandra_BatchType.LOGGED) as b: + TestMultiKeyModel.batch(b).create(partition=1, cluster=1) + TestMultiKeyModel.batch(b).create(partition=1, cluster=2) + + obj = TestMultiKeyModel.objects(partition=1) + self.assertEqual(2, len(obj)) + + @execute_count(4) + def test_cqlengine_batch_type(self): + """ + Tests the different types of `class: cassandra.cqlengine.query.BatchType` + + @since 3.13 + @jira_ticket PYTHON-88 + @expected_result batch query succeeds and the results + are correctly readen + + @test_category query + """ + with BatchQuery(batch_type=cqlengine_BatchType.Unlogged) as b: + TestMultiKeyModel.batch(b).create(partition=1, cluster=1) + TestMultiKeyModel.batch(b).create(partition=1, cluster=2) + + obj = TestMultiKeyModel.objects(partition=1) + self.assertEqual(2, len(obj)) + + with BatchQuery(batch_type=cqlengine_BatchType.Counter) as b: + CounterBatchQueryModel.batch(b).create(k=1, v=1) + CounterBatchQueryModel.batch(b).create(k=1, v=2) + CounterBatchQueryModel.batch(b).create(k=1, v=10) + + obj = CounterBatchQueryModel.objects(k=1) + self.assertEqual(1, len(obj)) + self.assertEqual(obj[0].v, 13) diff --git a/tests/integration/cqlengine/query/test_datetime_queries.py b/tests/integration/cqlengine/query/test_datetime_queries.py index 118d74ddcf..8225b2d9f3 100644 --- a/tests/integration/cqlengine/query/test_datetime_queries.py +++ b/tests/integration/cqlengine/query/test_datetime_queries.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -20,15 +22,17 @@ from cassandra.cqlengine.management import sync_table from cassandra.cqlengine.management import drop_table -from cassandra.cqlengine.models import Model, ModelException +from cassandra.cqlengine.models import Model from cassandra.cqlengine import columns -from cassandra.cqlengine import query +from tests.integration.cqlengine import execute_count + class DateTimeQueryTestModel(Model): - user = columns.Integer(primary_key=True) - day = columns.DateTime(primary_key=True) - data = columns.Text() + user = columns.Integer(primary_key=True) + day = columns.DateTime(primary_key=True) + data = columns.Text() + class TestDateTimeQueries(BaseCassEngTestCase): @@ -46,12 +50,12 @@ def setUpClass(cls): data=str(uuid4()) ) - @classmethod def tearDownClass(cls): super(TestDateTimeQueries, cls).tearDownClass() drop_table(DateTimeQueryTestModel) + @execute_count(1) def test_range_query(self): """ Tests that loading from a range of dates works properly """ start = datetime(*self.base_date.timetuple()[:3]) @@ -60,6 +64,7 @@ def test_range_query(self): results = DateTimeQueryTestModel.filter(user=0, day__gte=start, day__lt=end) assert len(results) == 3 + @execute_count(3) def test_datetime_precision(self): """ Tests that millisecond resolution is preserved when saving datetime objects """ now = datetime.now() diff --git a/tests/integration/cqlengine/query/test_named.py b/tests/integration/cqlengine/query/test_named.py index 3b51d1d216..b6ba23a2e1 100644 --- a/tests/integration/cqlengine/query/test_named.py +++ b/tests/integration/cqlengine/query/test_named.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,10 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
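The named-table tests that follow go through `NamedKeyspace`/`NamedTable` rather than a model class, filtering with `table.column(...)` comparisons. A minimal sketch of that access pattern; it assumes a reachable node and that the hypothetical table `demo_ks.readings` with a `sensor_id` column already exists:

    from cassandra.cqlengine import connection
    from cassandra.cqlengine.named import NamedKeyspace

    connection.setup(['127.0.0.1'], 'demo_ks')

    readings = NamedKeyspace('demo_ks').table('readings')

    # Aggregate and filtered reads mirror the assertions in the tests below.
    print(readings.objects.count())
    for row in readings.objects(readings.column('sensor_id') == 1).limit(5):
        print(row)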
-try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest from cassandra import ConsistencyLevel from cassandra.cqlengine import operators @@ -25,7 +24,7 @@ from cassandra.concurrent import execute_concurrent_with_args from cassandra.cqlengine import models -from tests.integration.cqlengine import setup_connection +from tests.integration.cqlengine import setup_connection, execute_count from tests.integration.cqlengine.base import BaseCassEngTestCase from tests.integration.cqlengine.query.test_queryset import BaseQuerySetUsage @@ -134,6 +133,7 @@ def setUpClass(cls): cls.keyspace = NamedKeyspace(ks) cls.table = cls.keyspace.table(tn) + @execute_count(2) def test_count(self): """ Tests that adding filtering statements affects the count query as expected """ assert self.table.objects.count() == 12 @@ -141,6 +141,7 @@ def test_count(self): q = self.table.objects(test_id=0) assert q.count() == 4 + @execute_count(2) def test_query_expression_count(self): """ Tests that adding query statements affects the count query as expected """ assert self.table.objects.count() == 12 @@ -148,6 +149,7 @@ def test_query_expression_count(self): q = self.table.objects(self.table.column('test_id') == 0) assert q.count() == 4 + @execute_count(3) def test_iteration(self): """ Tests that iterating over a query set pulls back all of the expected results """ q = self.table.objects(test_id=0) @@ -181,6 +183,7 @@ def test_iteration(self): compare_set.remove(val) assert len(compare_set) == 0 + @execute_count(2) def test_multiple_iterations_work_properly(self): """ Tests that iterating over a query set more than once works """ # test with both the filtering method and the query method @@ -201,6 +204,7 @@ def test_multiple_iterations_work_properly(self): compare_set.remove(val) assert len(compare_set) == 0 + @execute_count(2) def test_multiple_iterators_are_isolated(self): """ tests that the use of one iterator does not affect the behavior of another @@ -214,6 +218,7 @@ def test_multiple_iterators_are_isolated(self): assert next(iter1).attempt_id == attempt_id assert next(iter2).attempt_id == attempt_id + @execute_count(3) def test_get_success_case(self): """ Tests that the .get() method works on new and existing querysets @@ -235,6 +240,7 @@ def test_get_success_case(self): assert m.test_id == 0 assert m.attempt_id == 0 + @execute_count(3) def test_query_expression_get_success_case(self): """ Tests that the .get() method works on new and existing querysets @@ -256,6 +262,7 @@ def test_query_expression_get_success_case(self): assert m.test_id == 0 assert m.attempt_id == 0 + @execute_count(1) def test_get_doesnotexist_exception(self): """ Tests that get calls that don't return a result raises a DoesNotExist error @@ -263,6 +270,7 @@ def test_get_doesnotexist_exception(self): with self.assertRaises(self.table.DoesNotExist): self.table.objects.get(test_id=100) + @execute_count(1) def test_get_multipleobjects_exception(self): """ Tests that get calls that return multiple results raise a MultipleObjectsReturned error @@ -282,10 +290,10 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): models.DEFAULT_KEYSPACE = cls.default_keyspace - setup_connection(models.DEFAULT_KEYSPACE) super(TestNamedWithMV, cls).tearDownClass() @greaterthanorequalcass30 + @execute_count(5) def test_named_table_with_mv(self): """ Test NamedTable access to materialized views @@ -326,13 +334,13 @@ def test_named_table_with_mv(self): SELECT * FROM {0}.scores WHERE game IS NOT NULL AND score IS NOT NULL AND 
user IS NOT NULL AND year IS NOT NULL AND month IS NOT NULL AND day IS NOT NULL PRIMARY KEY (game, score, user, year, month, day) - WITH CLUSTERING ORDER BY (score DESC)""".format(ks) + WITH CLUSTERING ORDER BY (score DESC, user DESC, year DESC, month DESC, day DESC)""".format(ks) self.session.execute(create_mv_alltime) # Populate the base table with data prepared_insert = self.session.prepare("""INSERT INTO {0}.scores (user, game, year, month, day, score) VALUES (?, ?, ? ,? ,?, ?)""".format(ks)) - parameters = {('pcmanus', 'Coup', 2015, 5, 1, 4000), + parameters = (('pcmanus', 'Coup', 2015, 5, 1, 4000), ('jbellis', 'Coup', 2015, 5, 3, 1750), ('yukim', 'Coup', 2015, 5, 3, 2250), ('tjake', 'Coup', 2015, 5, 3, 500), @@ -343,7 +351,7 @@ def test_named_table_with_mv(self): ('jbellis', 'Coup', 2015, 6, 20, 3500), ('jbellis', 'Checkers', 2015, 6, 20, 1200), ('jbellis', 'Chess', 2015, 6, 21, 3500), - ('pcmanus', 'Chess', 2015, 1, 25, 3200)} + ('pcmanus', 'Chess', 2015, 1, 25, 3200)) prepared_insert.consistency_level = ConsistencyLevel.ALL execute_concurrent_with_args(self.session, prepared_insert, parameters) diff --git a/tests/integration/cqlengine/query/test_queryoperators.py b/tests/integration/cqlengine/query/test_queryoperators.py index af10435ca4..8f0dae06e7 100644 --- a/tests/integration/cqlengine/query/test_queryoperators.py +++ b/tests/integration/cqlengine/query/test_queryoperators.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -19,10 +21,13 @@ from cassandra.cqlengine import query from cassandra.cqlengine.management import sync_table, drop_table from cassandra.cqlengine.models import Model +from cassandra.cqlengine.named import NamedTable from cassandra.cqlengine.operators import EqualsOperator from cassandra.cqlengine.statements import WhereClause - +from tests.integration.cqlengine import DEFAULT_KEYSPACE from tests.integration.cqlengine.base import BaseCassEngTestCase +from tests.integration.cqlengine import execute_count + class TestQuerySetOperation(BaseCassEngTestCase): @@ -54,7 +59,7 @@ def test_mintimeuuid_function(self): class TokenTestModel(Model): - + __table_name__ = "token_test_model" key = columns.Integer(primary_key=True) val = columns.Integer() @@ -69,12 +74,13 @@ def tearDown(self): super(TestTokenFunction, self).tearDown() drop_table(TokenTestModel) + @execute_count(15) def test_token_function(self): """ Tests that token functions work properly """ - assert TokenTestModel.objects().count() == 0 + assert TokenTestModel.objects.count() == 0 for i in range(10): TokenTestModel.create(key=i, val=i) - assert TokenTestModel.objects().count() == 10 + assert TokenTestModel.objects.count() == 10 seen_keys = set() last_token = None for instance in TokenTestModel.objects().limit(5): @@ -87,6 +93,11 @@ def test_token_function(self): assert len(seen_keys) == 10 assert all([i in seen_keys for i in range(10)]) + # pk__token equality + r = TokenTestModel.objects(pk__token=functions.Token(last_token)) + self.assertEqual(len(r), 1) + r.all() # Attempt to obtain queryset for results. This has thrown an exception in the past + def test_compound_pk_token_function(self): class TestModel(Model): @@ -124,3 +135,27 @@ class TestModel(Model): # The # of arguments to Token must match the # of partition keys func = functions.Token('a') self.assertRaises(query.QueryException, TestModel.objects.filter, pk__token__gt=func) + + @execute_count(7) + def test_named_table_pk_token_function(self): + """ + Test to ensure that token functions work with named tables. + + @since 3.2 + @jira_ticket PYTHON-272 + @expected_result partition key token functions should allow pagination. Prior to PYTHON-272 + this would fail with an AttributeError + + @test_category object_mapper + """ + + for i in range(5): + TokenTestModel.create(key=i, val=i) + named = NamedTable(DEFAULT_KEYSPACE, TokenTestModel.__table_name__) + + query = named.all().limit(1) + first_page = list(query) + last = first_page[-1] + self.assertEqual(len(first_page), 1) + next_page = list(query.filter(pk__token__gt=functions.Token(last.key))) + self.assertEqual(len(next_page), 1) diff --git a/tests/integration/cqlengine/query/test_queryset.py b/tests/integration/cqlengine/query/test_queryset.py index 0097e60490..d09d7eeb04 100644 --- a/tests/integration/cqlengine/query/test_queryset.py +++ b/tests/integration/cqlengine/query/test_queryset.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -13,22 +15,18 @@ # limitations under the License. from __future__ import absolute_import -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest from datetime import datetime -import time -from uuid import uuid1, uuid4 +from uuid import uuid4 +from packaging.version import Version import uuid -import sys from cassandra.cluster import Session from cassandra import InvalidRequest from tests.integration.cqlengine.base import BaseCassEngTestCase from cassandra.cqlengine.connection import NOT_SET -import mock +from unittest import mock from cassandra.cqlengine import functions from cassandra.cqlengine.management import sync_table, drop_table from cassandra.cqlengine.models import Model @@ -41,9 +39,10 @@ from cassandra.cqlengine import statements from cassandra.cqlengine import operators from cassandra.util import uuid_from_time - from cassandra.cqlengine.connection import get_session -from tests.integration import PROTOCOL_VERSION, CASSANDRA_VERSION, greaterthancass20, greaterthancass21 +from tests.integration import PROTOCOL_VERSION, CASSANDRA_VERSION, greaterthancass20, greaterthancass21, \ + greaterthanorequalcass30, TestCluster +from tests.integration.cqlengine import execute_count, DEFAULT_KEYSPACE class TzOffset(tzinfo): @@ -82,6 +81,14 @@ class IndexedTestModel(Model): test_result = columns.Integer(index=True) +class CustomIndexedTestModel(Model): + + test_id = columns.Integer(primary_key=True) + description = columns.Text(custom_index=True) + indexed = columns.Text(index=True) + data = columns.Text() + + class IndexedCollectionsTestModel(Model): test_id = columns.Integer(primary_key=True) @@ -106,6 +113,7 @@ class TestMultiClusteringModel(Model): class TestQuerySetOperation(BaseCassEngTestCase): + def test_query_filter_parsing(self): """ Tests the queryset filter method parses it's kwargs properly @@ -218,26 +226,106 @@ def test_queryset_with_distinct(self): query3 = TestModel.objects.distinct(['test_id', 'attempt_id']) self.assertEqual(len(query3._distinct_fields), 2) - def test_defining_only_and_defer_fails(self): + def test_defining_only_fields(self): """ - Tests that trying to add fields to either only or defer, or doing so more than once fails + Tests defining only fields + + @since 3.5 + @jira_ticket PYTHON-560 + @expected_result deferred fields should not be returned + + @test_category object_mapper """ + # simple only definition + q = TestModel.objects.only(['attempt_id', 'description']) + self.assertEqual(q._select_fields(), ['attempt_id', 'description']) + + with self.assertRaises(query.QueryException): + TestModel.objects.only(['nonexistent_field']) + + # Cannot define more than once only fields + with self.assertRaises(query.QueryException): + TestModel.objects.only(['description']).only(['attempt_id']) - def test_defining_only_or_defer_on_nonexistant_fields_fails(self): + # only with defer fields + q = 
TestModel.objects.only(['attempt_id', 'description']) + q = q.defer(['description']) + self.assertEqual(q._select_fields(), ['attempt_id']) + + # Eliminate all results confirm exception is thrown + q = TestModel.objects.only(['description']) + q = q.defer(['description']) + with self.assertRaises(query.QueryException): + q._select_fields() + + q = TestModel.objects.filter(test_id=0).only(['test_id', 'attempt_id', 'description']) + self.assertEqual(q._select_fields(), ['attempt_id', 'description']) + + # no fields to select + with self.assertRaises(query.QueryException): + q = TestModel.objects.only(['test_id']).defer(['test_id']) + q._select_fields() + + with self.assertRaises(query.QueryException): + q = TestModel.objects.filter(test_id=0).only(['test_id']) + q._select_fields() + + def test_defining_defer_fields(self): """ - Tests that setting only or defer fields that don't exist raises an exception + Tests defining defer fields + + @since 3.5 + @jira_ticket PYTHON-560 + @jira_ticket PYTHON-599 + @expected_result deferred fields should not be returned + + @test_category object_mapper """ + # simple defer definition + q = TestModel.objects.defer(['attempt_id', 'description']) + self.assertEqual(q._select_fields(), ['test_id', 'expected_result', 'test_result']) + + with self.assertRaises(query.QueryException): + TestModel.objects.defer(['nonexistent_field']) + + # defer more than one + q = TestModel.objects.defer(['attempt_id', 'description']) + q = q.defer(['expected_result']) + self.assertEqual(q._select_fields(), ['test_id', 'test_result']) + + # defer with only + q = TestModel.objects.defer(['description', 'attempt_id']) + q = q.only(['description', 'test_id']) + self.assertEqual(q._select_fields(), ['test_id']) + + # Eliminate all results confirm exception is thrown + q = TestModel.objects.defer(['description', 'attempt_id']) + q = q.only(['description']) + with self.assertRaises(query.QueryException): + q._select_fields() + + # implicit defer + q = TestModel.objects.filter(test_id=0) + self.assertEqual(q._select_fields(), ['attempt_id', 'description', 'expected_result', 'test_result']) + + # when all fields are defered, it fallbacks select the partition keys + q = TestModel.objects.defer(['test_id', 'attempt_id', 'description', 'expected_result', 'test_result']) + self.assertEqual(q._select_fields(), ['test_id']) + class BaseQuerySetUsage(BaseCassEngTestCase): + @classmethod def setUpClass(cls): super(BaseQuerySetUsage, cls).setUpClass() drop_table(TestModel) drop_table(IndexedTestModel) + drop_table(CustomIndexedTestModel) sync_table(TestModel) sync_table(IndexedTestModel) + sync_table(CustomIndexedTestModel) sync_table(TestMultiClusteringModel) TestModel.objects.create(test_id=0, attempt_id=0, description='try1', expected_result=5, test_result=30) @@ -273,7 +361,7 @@ def setUpClass(cls): IndexedTestModel.objects.create(test_id=11, attempt_id=3, description='try12', expected_result=75, test_result=45) - if(CASSANDRA_VERSION >= '2.1'): + if CASSANDRA_VERSION >= Version('2.1'): drop_table(IndexedCollectionsTestModel) sync_table(IndexedCollectionsTestModel) IndexedCollectionsTestModel.objects.create(test_id=12, attempt_id=3, description='list12', expected_result=75, @@ -295,10 +383,13 @@ def tearDownClass(cls): super(BaseQuerySetUsage, cls).tearDownClass() drop_table(TestModel) drop_table(IndexedTestModel) + drop_table(CustomIndexedTestModel) drop_table(TestMultiClusteringModel) class TestQuerySetCountSelectionAndIteration(BaseQuerySetUsage): + + @execute_count(2) def test_count(self): 
""" Tests that adding filtering statements affects the count query as expected """ assert TestModel.objects.count() == 12 @@ -306,6 +397,7 @@ def test_count(self): q = TestModel.objects(test_id=0) assert q.count() == 4 + @execute_count(2) def test_query_expression_count(self): """ Tests that adding query statements affects the count query as expected """ assert TestModel.objects.count() == 12 @@ -313,6 +405,7 @@ def test_query_expression_count(self): q = TestModel.objects(TestModel.test_id == 0) assert q.count() == 4 + @execute_count(3) def test_iteration(self): """ Tests that iterating over a query set pulls back all of the expected results """ q = TestModel.objects(test_id=0) @@ -346,6 +439,7 @@ def test_iteration(self): compare_set.remove(val) assert len(compare_set) == 0 + @execute_count(2) def test_multiple_iterations_work_properly(self): """ Tests that iterating over a query set more than once works """ # test with both the filtering method and the query method @@ -366,6 +460,7 @@ def test_multiple_iterations_work_properly(self): compare_set.remove(val) assert len(compare_set) == 0 + @execute_count(2) def test_multiple_iterators_are_isolated(self): """ tests that the use of one iterator does not affect the behavior of another @@ -379,6 +474,7 @@ def test_multiple_iterators_are_isolated(self): assert next(iter1).attempt_id == attempt_id assert next(iter2).attempt_id == attempt_id + @execute_count(3) def test_get_success_case(self): """ Tests that the .get() method works on new and existing querysets @@ -400,6 +496,7 @@ def test_get_success_case(self): assert m.test_id == 0 assert m.attempt_id == 0 + @execute_count(3) def test_query_expression_get_success_case(self): """ Tests that the .get() method works on new and existing querysets @@ -421,6 +518,7 @@ def test_query_expression_get_success_case(self): assert m.test_id == 0 assert m.attempt_id == 0 + @execute_count(1) def test_get_doesnotexist_exception(self): """ Tests that get calls that don't return a result raises a DoesNotExist error @@ -428,6 +526,7 @@ def test_get_doesnotexist_exception(self): with self.assertRaises(TestModel.DoesNotExist): TestModel.objects.get(test_id=100) + @execute_count(1) def test_get_multipleobjects_exception(self): """ Tests that get calls that return multiple results raise a MultipleObjectsReturned error @@ -439,7 +538,7 @@ def test_allow_filtering_flag(self): """ """ - +@execute_count(4) def test_non_quality_filtering(): class NonEqualityFilteringModel(Model): @@ -457,35 +556,41 @@ class NonEqualityFilteringModel(Model): NonEqualityFilteringModel.create(sequence_id=3, example_type=0, created_at=datetime.now()) NonEqualityFilteringModel.create(sequence_id=5, example_type=1, created_at=datetime.now()) - qA = NonEqualityFilteringModel.objects(NonEqualityFilteringModel.sequence_id > 3).allow_filtering() - num = qA.count() + qa = NonEqualityFilteringModel.objects(NonEqualityFilteringModel.sequence_id > 3).allow_filtering() + num = qa.count() assert num == 1, num class TestQuerySetDistinct(BaseQuerySetUsage): + @execute_count(1) def test_distinct_without_parameter(self): q = TestModel.objects.distinct() self.assertEqual(len(q), 3) + @execute_count(1) def test_distinct_with_parameter(self): q = TestModel.objects.distinct(['test_id']) self.assertEqual(len(q), 3) + @execute_count(1) def test_distinct_with_filter(self): - q = TestModel.objects.distinct(['test_id']).filter(test_id__in=[1,2]) + q = TestModel.objects.distinct(['test_id']).filter(test_id__in=[1, 2]) self.assertEqual(len(q), 2) + @execute_count(1) 
def test_distinct_with_non_partition(self): with self.assertRaises(InvalidRequest): q = TestModel.objects.distinct(['description']).filter(test_id__in=[1, 2]) len(q) + @execute_count(1) def test_zero_result(self): q = TestModel.objects.distinct(['test_id']).filter(test_id__in=[52]) self.assertEqual(len(q), 0) @greaterthancass21 + @execute_count(2) def test_distinct_with_explicit_count(self): q = TestModel.objects.distinct(['test_id']) self.assertEqual(q.count(), 3) @@ -495,7 +600,7 @@ def test_distinct_with_explicit_count(self): class TestQuerySetOrdering(BaseQuerySetUsage): - + @execute_count(2) def test_order_by_success_case(self): q = TestModel.objects(test_id=0).order_by('attempt_id') expected_order = [0, 1, 2, 3] @@ -510,20 +615,21 @@ def test_order_by_success_case(self): def test_ordering_by_non_second_primary_keys_fail(self): # kwarg filtering with self.assertRaises(query.QueryException): - q = TestModel.objects(test_id=0).order_by('test_id') + TestModel.objects(test_id=0).order_by('test_id') # kwarg filtering with self.assertRaises(query.QueryException): - q = TestModel.objects(TestModel.test_id == 0).order_by('test_id') + TestModel.objects(TestModel.test_id == 0).order_by('test_id') def test_ordering_by_non_primary_keys_fails(self): with self.assertRaises(query.QueryException): - q = TestModel.objects(test_id=0).order_by('description') + TestModel.objects(test_id=0).order_by('description') def test_ordering_on_indexed_columns_fails(self): with self.assertRaises(query.QueryException): - q = IndexedTestModel.objects(test_id=0).order_by('attempt_id') + IndexedTestModel.objects(test_id=0).order_by('attempt_id') + @execute_count(8) def test_ordering_on_multiple_clustering_columns(self): TestMultiClusteringModel.create(one=1, two=1, three=4) TestMultiClusteringModel.create(one=1, two=1, three=2) @@ -542,23 +648,28 @@ def test_ordering_on_multiple_clustering_columns(self): class TestQuerySetSlicing(BaseQuerySetUsage): + + @execute_count(1) def test_out_of_range_index_raises_error(self): q = TestModel.objects(test_id=0).order_by('attempt_id') with self.assertRaises(IndexError): q[10] + @execute_count(1) def test_array_indexing_works_properly(self): q = TestModel.objects(test_id=0).order_by('attempt_id') expected_order = [0, 1, 2, 3] for i in range(len(q)): assert q[i].attempt_id == expected_order[i] + @execute_count(1) def test_negative_indexing_works_properly(self): q = TestModel.objects(test_id=0).order_by('attempt_id') expected_order = [0, 1, 2, 3] assert q[-1].attempt_id == expected_order[-1] assert q[-2].attempt_id == expected_order[-2] + @execute_count(1) def test_slicing_works_properly(self): q = TestModel.objects(test_id=0).order_by('attempt_id') expected_order = [0, 1, 2, 3] @@ -569,6 +680,7 @@ def test_slicing_works_properly(self): for model, expect in zip(q[0:3:2], expected_order[0:3:2]): self.assertEqual(model.attempt_id, expect) + @execute_count(1) def test_negative_slicing(self): q = TestModel.objects(test_id=0).order_by('attempt_id') expected_order = [0, 1, 2, 3] @@ -590,6 +702,7 @@ def test_negative_slicing(self): class TestQuerySetValidation(BaseQuerySetUsage): + def test_primary_key_or_index_must_be_specified(self): """ Tests that queries that don't have an equals relation to a primary key or indexed field fail @@ -607,6 +720,7 @@ def test_primary_key_or_index_must_have_equal_relation_filter(self): list([i for i in q]) @greaterthancass20 + @execute_count(7) def test_indexed_field_can_be_queried(self): """ Tests that queries on an indexed field will work without any 
primary key relations specified @@ -632,8 +746,45 @@ def test_indexed_field_can_be_queried(self): q = IndexedCollectionsTestModel.objects.filter(test_map__contains=13) self.assertEqual(q.count(), 0) + def test_custom_indexed_field_can_be_queried(self): + """ + Tests that queries on a custom indexed field will work without any primary key relations specified + """ + + with self.assertRaises(query.QueryException): + list(CustomIndexedTestModel.objects.filter(data='test')) # not custom indexed + + # It should raise InvalidRequest when targeting an indexed column + with self.assertRaises(InvalidRequest): + list(CustomIndexedTestModel.objects.filter(indexed='test', data='test')) + + # It should raise InvalidRequest when targeting an indexed column + with self.assertRaises(InvalidRequest): + list(CustomIndexedTestModel.objects.filter(description='test', data='test')) + + # equals operator, server error since there is no real index, but it passes + with self.assertRaises(InvalidRequest): + list(CustomIndexedTestModel.objects.filter(description='test')) + + with self.assertRaises(InvalidRequest): + list(CustomIndexedTestModel.objects.filter(test_id=1, description='test')) + + # gte operator, server error since there is no real index, but it passes + # this can't work with a secondary index + with self.assertRaises(InvalidRequest): + list(CustomIndexedTestModel.objects.filter(description__gte='test')) + + with TestCluster().connect() as session: + session.execute("CREATE INDEX custom_index_cqlengine ON {}.{} (description)". + format(DEFAULT_KEYSPACE, CustomIndexedTestModel._table_name)) + + list(CustomIndexedTestModel.objects.filter(description='test')) + list(CustomIndexedTestModel.objects.filter(test_id=1, description='test')) + class TestQuerySetDelete(BaseQuerySetUsage): + + @execute_count(9) def test_delete(self): TestModel.objects.create(test_id=3, attempt_id=0, description='try9', expected_result=50, test_result=40) TestModel.objects.create(test_id=3, attempt_id=1, description='try10', expected_result=60, test_result=40) @@ -658,7 +809,8 @@ def test_delete_without_any_where_args(self): with self.assertRaises(query.QueryException): TestModel.objects(attempt_id=0).delete() - @unittest.skipIf(CASSANDRA_VERSION < '3.0', "range deletion was introduce in C* 3.0, currently running {0}".format(CASSANDRA_VERSION)) + @greaterthanorequalcass30 + @execute_count(18) def test_range_deletion(self): """ Tests that range deletion work as expected @@ -673,7 +825,7 @@ def test_range_deletion(self): TestMultiClusteringModel.objects(one=1, two__gt=3, two__lt=5).delete() self.assertEqual(5, len(TestMultiClusteringModel.objects.all())) - TestMultiClusteringModel.objects(one=1, two__in=[8,9]).delete() + TestMultiClusteringModel.objects(one=1, two__in=[8, 9]).delete() self.assertEqual(3, len(TestMultiClusteringModel.objects.all())) TestMultiClusteringModel.objects(one__in=[1], two__gte=0).delete() @@ -698,6 +850,7 @@ def tearDownClass(cls): super(TestMinMaxTimeUUIDFunctions, cls).tearDownClass() drop_table(TimeUUIDQueryModel) + @execute_count(7) def test_tzaware_datetime_support(self): """Test that using timezone aware datetime instances works with the MinTimeUUID/MaxTimeUUID functions. 
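The hunk that follows replaces sleep()-spaced uuid1() timestamps with deterministic TimeUUIDs built from cassandra.util.uuid_from_time, so the midpoint handed to MinTimeUUID/MaxTimeUUID is known exactly. A minimal sketch of the filtering pattern these tests exercise (the model name, keyspace, and contact point below are illustrative assumptions, not part of this patch, and a reachable Cassandra node is required):

    from datetime import datetime, timedelta
    from uuid import uuid4

    from cassandra.cqlengine import columns, connection, functions
    from cassandra.cqlengine.management import sync_table
    from cassandra.cqlengine.models import Model
    from cassandra.util import uuid_from_time


    class EventSketch(Model):
        # Illustrative model, not one defined in this patch
        partition = columns.UUID(primary_key=True)
        time = columns.TimeUUID(primary_key=True)
        data = columns.Text()


    # Assumes a local node and an existing 'sketch_ks' keyspace;
    # set CQLENG_ALLOW_SCHEMA_MANAGEMENT to silence the schema-management warning.
    connection.setup(['127.0.0.1'], 'sketch_ks')
    sync_table(EventSketch)

    pk, start = uuid4(), datetime.utcnow()
    for offset, payload in ((1, 'a'), (2, 'b'), (4, 'c'), (5, 'd')):
        EventSketch.create(partition=pk,
                           time=uuid_from_time(start + timedelta(seconds=offset)),
                           data=payload)

    midpoint = start + timedelta(seconds=3)
    # MaxTimeUUID/MinTimeUUID turn a wall-clock bound into a TimeUUID bound
    older = EventSketch.filter(partition=pk, time__lte=functions.MaxTimeUUID(midpoint))
    newer = EventSketch.filter(partition=pk, time__gte=functions.MinTimeUUID(midpoint))
    assert {r.data for r in older} == {'a', 'b'}
    assert {r.data for r in newer} == {'c', 'd'}

Because the TimeUUIDs are derived from fixed offsets rather than wall-clock sleeps, the split around the midpoint is deterministic, which is the point of the rewrite in the next hunk.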
@@ -741,19 +894,16 @@ def test_tzaware_datetime_support(self): TimeUUIDQueryModel.partition == pk, TimeUUIDQueryModel.time >= functions.MinTimeUUID(midpoint_helsinki))] + @execute_count(8) def test_success_case(self): """ Test that the min and max time uuid functions work as expected """ pk = uuid4() - TimeUUIDQueryModel.create(partition=pk, time=uuid1(), data='1') - time.sleep(0.2) - TimeUUIDQueryModel.create(partition=pk, time=uuid1(), data='2') - time.sleep(0.2) - midpoint = datetime.utcnow() - time.sleep(0.2) - TimeUUIDQueryModel.create(partition=pk, time=uuid1(), data='3') - time.sleep(0.2) - TimeUUIDQueryModel.create(partition=pk, time=uuid1(), data='4') - time.sleep(0.2) + startpoint = datetime.utcnow() + TimeUUIDQueryModel.create(partition=pk, time=uuid_from_time(startpoint + timedelta(seconds=1)), data='1') + TimeUUIDQueryModel.create(partition=pk, time=uuid_from_time(startpoint + timedelta(seconds=2)), data='2') + midpoint = startpoint + timedelta(seconds=3) + TimeUUIDQueryModel.create(partition=pk, time=uuid_from_time(startpoint + timedelta(seconds=4)), data='3') + TimeUUIDQueryModel.create(partition=pk, time=uuid_from_time(startpoint + timedelta(seconds=5)), data='4') # test kwarg filtering q = TimeUUIDQueryModel.filter(partition=pk, time__lte=functions.MaxTimeUUID(midpoint)) @@ -791,19 +941,68 @@ def test_success_case(self): class TestInOperator(BaseQuerySetUsage): + @execute_count(1) def test_kwarg_success_case(self): """ Tests the in operator works with the kwarg query method """ q = TestModel.filter(test_id__in=[0, 1]) assert q.count() == 8 + @execute_count(1) def test_query_expression_success_case(self): """ Tests the in operator works with the query expression query method """ q = TestModel.filter(TestModel.test_id.in_([0, 1])) assert q.count() == 8 + @execute_count(5) + def test_bool(self): + """ + Adding coverage to cqlengine for bool types. 
+ + @since 3.6 + @jira_ticket PYTHON-596 + @expected_result bool results should be filtered appropriately + + @test_category object_mapper + """ + class bool_model(Model): + k = columns.Integer(primary_key=True) + b = columns.Boolean(primary_key=True) + v = columns.Integer(default=3) + sync_table(bool_model) + + bool_model.create(k=0, b=True) + bool_model.create(k=0, b=False) + self.assertEqual(len(bool_model.objects.all()), 2) + self.assertEqual(len(bool_model.objects.filter(k=0, b=True)), 1) + self.assertEqual(len(bool_model.objects.filter(k=0, b=False)), 1) + + @execute_count(3) + def test_bool_filter(self): + """ + Test to ensure that we don't translate boolean objects to String unnecessarily in filter clauses + + @since 3.6 + @jira_ticket PYTHON-596 + @expected_result We should not receive a server error + + @test_category object_mapper + """ + class bool_model2(Model): + k = columns.Boolean(primary_key=True) + b = columns.Integer(primary_key=True) + v = columns.Text() + drop_table(bool_model2) + sync_table(bool_model2) + + bool_model2.create(k=True, b=1, v='a') + bool_model2.create(k=False, b=1, v='b') + self.assertEqual(len(list(bool_model2.objects(k__in=(True, False)))), 2) + @greaterthancass20 class TestContainsOperator(BaseQuerySetUsage): + + @execute_count(6) def test_kwarg_success_case(self): """ Tests the CONTAINS operator works with the kwarg query method """ q = IndexedCollectionsTestModel.filter(test_list__contains=1) @@ -834,6 +1033,7 @@ def test_kwarg_success_case(self): q = IndexedCollectionsTestModel.filter(test_map_no_index__contains=1) self.assertEqual(q.count(), 0) + @execute_count(6) def test_query_expression_success_case(self): """ Tests the CONTAINS operator works with the query expression query method """ q = IndexedCollectionsTestModel.filter(IndexedCollectionsTestModel.test_list.contains_(1)) @@ -866,6 +1066,8 @@ def test_query_expression_success_case(self): class TestValuesList(BaseQuerySetUsage): + + @execute_count(2) def test_values_list(self): q = TestModel.objects.filter(test_id=0, attempt_id=1) item = q.values_list('test_id', 'attempt_id', 'description', 'expected_result', 'test_result').first() @@ -876,13 +1078,15 @@ def test_values_list(self): class TestObjectsProperty(BaseQuerySetUsage): + @execute_count(1) def test_objects_property_returns_fresh_queryset(self): assert TestModel.objects._result_cache is None - len(TestModel.objects) # evaluate queryset + len(TestModel.objects) # evaluate queryset assert TestModel.objects._result_cache is None class PageQueryTests(BaseCassEngTestCase): + @execute_count(3) def test_paged_result_handling(self): if PROTOCOL_VERSION < 2: raise unittest.SkipTest("Paging requires native protocol 2+, currently using: {0}".format(PROTOCOL_VERSION)) @@ -979,19 +1183,18 @@ class DBFieldModelMixed2(Model): class TestModelQueryWithDBField(BaseCassEngTestCase): - @classmethod - def setUpClass(cls): + def setUp(cls): super(TestModelQueryWithDBField, cls).setUpClass() cls.model_list = [DBFieldModel, DBFieldModelMixed1, DBFieldModelMixed2] for model in cls.model_list: sync_table(model) - @classmethod - def tearDownClass(cls): + def tearDown(cls): super(TestModelQueryWithDBField, cls).tearDownClass() for model in cls.model_list: drop_table(model) + @execute_count(33) def test_basic_crud(self): """ Tests creation update and delete of object model queries that are using db_field mappings. 
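The TestModelQueryWithDBField hunks below run CRUD, slice, order, and index queries against models whose Python attribute names are mapped to different CQL column names through db_field. A minimal sketch of that mapping (the model here is illustrative, not the DBFieldModel defined in this test module):

    from cassandra.cqlengine import columns
    from cassandra.cqlengine.models import Model


    class DBFieldSketch(Model):
        # Illustrative model: attributes are exposed as k0/c0/v0 in Python,
        # but stored in the CQL columns "a", "b" and "c".
        __table_name__ = 'db_field_sketch'
        k0 = columns.Integer(primary_key=True, partition_key=True, db_field='a')
        c0 = columns.Integer(primary_key=True, db_field='b')
        v0 = columns.Text(db_field='c')


    # Application code keeps using the attribute names...
    qs = DBFieldSketch.objects(k0=1, c0=2)
    # ...while the CQL that cqlengine renders refers only to the db_field names
    # ("a" and "b" in the WHERE clause); test_db_field_names_used asserts the same
    # property for generated UPDATE statements.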
@@ -1028,6 +1231,7 @@ def test_basic_crud(self): i = model.objects(k0=i.k0, k1=i.k1).first() self.assertIsNone(i) + @execute_count(21) def test_slice(self): """ Tests slice queries for object models that are using db_field mapping @@ -1050,6 +1254,7 @@ def test_slice(self): self.assertEqual(model.objects(k0=i.k0, k1=i.k1, c0__lt=i.c0).count(), len(clustering_values[:-1])) self.assertEqual(model.objects(k0=i.k0, k1=i.k1, c0__gt=0).count(), len(clustering_values[1:])) + @execute_count(15) def test_order(self): """ Tests order by queries for object models that are using db_field mapping @@ -1069,6 +1274,7 @@ def test_order(self): self.assertEqual(model.objects(k0=i.k0, k1=i.k1).order_by('c0').first().c0, clustering_values[0]) self.assertEqual(model.objects(k0=i.k0, k1=i.k1).order_by('-c0').first().c0, clustering_values[-1]) + @execute_count(15) def test_index(self): """ Tests queries using index fields for object models using db_field mapping @@ -1089,6 +1295,47 @@ def test_index(self): self.assertEqual(model.objects(k0=i.k0, k1=i.k1).count(), len(clustering_values)) self.assertEqual(model.objects(k0=i.k0, k1=i.k1, v1=0).count(), 1) + @execute_count(1) + def test_db_field_names_used(self): + """ + Tests to ensure that with generated cql update statements correctly utilize the db_field values. + + @since 3.2 + @jira_ticket PYTHON-530 + @expected_result resulting cql_statements will use the db_field values + + @test_category object_mapper + """ + + values = ('k0', 'k1', 'c0', 'v0', 'v1') + # Test QuerySet Path + b = BatchQuery() + DBFieldModel.objects(k0=1).batch(b).update( + v0=0, + v1=9, + ) + for value in values: + self.assertTrue(value not in str(b.queries[0])) + + # Test DML path + b2 = BatchQuery() + dml_field_model = DBFieldModel.create(k0=1, k1=5, c0=3, v0=4, v1=5) + dml_field_model.batch(b2).update( + v0=0, + v1=9, + ) + for value in values: + self.assertTrue(value not in str(b2.queries[0])) + + def test_db_field_value_list(self): + DBFieldModel.create(k0=0, k1=0, c0=0, v0=4, v1=5) + + self.assertEqual(DBFieldModel.objects.filter(c0=0, k0=0, k1=0).values_list('c0', 'v0')._defer_fields, + {'a', 'c', 'b'}) + self.assertEqual(DBFieldModel.objects.filter(c0=0, k0=0, k1=0).values_list('c0', 'v0')._only_fields, + ['c', 'd']) + + list(DBFieldModel.objects.filter(c0=0, k0=0, k1=0).values_list('c0', 'v0')) class TestModelSmall(Model): @@ -1117,10 +1364,21 @@ def tearDownClass(cls): super(TestModelQueryWithFetchSize, cls).tearDownClass() drop_table(TestModelSmall) + @execute_count(19) def test_defaultFetchSize(self): + # Use smaller batch sizes to avoid hitting the max. We trigger an InvalidRequest + # response for Cassandra 4.1.x and 5.0.x if we just do the whole thing as one + # large batch. We're just using this to populate values for a test, however, + # so shifting to smaller batches should be fine. 
+ for i in range(0, 5000, 500): + with BatchQuery() as b: + range_max = i + 500 + for j in range(i, range_max): + TestModelSmall.batch(b).create(test_id=j) with BatchQuery() as b: - for i in range(5100): + for i in range(5000, 5100): TestModelSmall.batch(b).create(test_id=i) + self.assertEqual(len(TestModelSmall.objects.fetch_size(1)), 5100) self.assertEqual(len(TestModelSmall.objects.fetch_size(500)), 5100) self.assertEqual(len(TestModelSmall.objects.fetch_size(4999)), 5100) @@ -1134,3 +1392,74 @@ def test_defaultFetchSize(self): TestModelSmall.objects.fetch_size(0) with self.assertRaises(QueryException): TestModelSmall.objects.fetch_size(-1) + + +class People(Model): + __table_name__ = "people" + last_name = columns.Text(primary_key=True, partition_key=True) + first_name = columns.Text(primary_key=True) + birthday = columns.DateTime() + + +class People2(Model): + __table_name__ = "people" + last_name = columns.Text(primary_key=True, partition_key=True) + first_name = columns.Text(primary_key=True) + middle_name = columns.Text() + birthday = columns.DateTime() + + +class TestModelQueryWithDifferedFeld(BaseCassEngTestCase): + """ + Tests that selects with filter will defer population of known values until after the results are returned. + I.e. instead of generating SELECT * FROM People WHERE last_name="Smith" it will generate + SELECT first_name, birthday FROM People WHERE last_name="Smith" + where last_name 'Smith' will be populated after the query returns + + @since 3.2 + @jira_ticket PYTHON-520 + @expected_result only needed fields are included in the query + + @test_category object_mapper + """ + @classmethod + def setUpClass(cls): + super(TestModelQueryWithDifferedFeld, cls).setUpClass() + sync_table(People) + + @classmethod + def tearDownClass(cls): + super(TestModelQueryWithDifferedFeld, cls).tearDownClass() + drop_table(People) + + @execute_count(8) + def test_defaultFetchSize(self): + # Populate Table + People.objects.create(last_name="Smith", first_name="John", birthday=datetime.now()) + People.objects.create(last_name="Bestwater", first_name="Alan", birthday=datetime.now()) + People.objects.create(last_name="Smith", first_name="Greg", birthday=datetime.now()) + People.objects.create(last_name="Smith", first_name="Adam", birthday=datetime.now()) + + # Check query construction + expected_fields = ['first_name', 'birthday'] + self.assertEqual(People.filter(last_name="Smith")._select_fields(), expected_fields) + # Validate correct fields are fetched + smiths = list(People.filter(last_name="Smith")) + self.assertEqual(len(smiths), 3) + self.assertTrue(smiths[0].last_name is not None) + + # Modify table with new value + sync_table(People2) + + # populate new format + People2.objects.create(last_name="Smith", first_name="Chris", middle_name="Raymond", birthday=datetime.now()) + People2.objects.create(last_name="Smith", first_name="Andrew", middle_name="Micheal", birthday=datetime.now()) + + # validate query construction + expected_fields = ['first_name', 'middle_name', 'birthday'] + self.assertEqual(People2.filter(last_name="Smith")._select_fields(), expected_fields) + + # validate correct items are returned + smiths = list(People2.filter(last_name="Smith")) + self.assertEqual(len(smiths), 5) + self.assertTrue(smiths[0].last_name is not None) diff --git a/tests/integration/cqlengine/query/test_updates.py b/tests/integration/cqlengine/query/test_updates.py index afa3a7f096..b0b9155ea2 100644 --- a/tests/integration/cqlengine/query/test_updates.py +++ 
b/tests/integration/cqlengine/query/test_updates.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -18,19 +20,11 @@ from cassandra.cqlengine.models import Model from cassandra.cqlengine.management import sync_table, drop_table from cassandra.cqlengine import columns -from tests.integration.cqlengine import is_prepend_reversed -from tests.integration.cqlengine.base import BaseCassEngTestCase - - -class TestQueryUpdateModel(Model): - partition = columns.UUID(primary_key=True, default=uuid4) - cluster = columns.Integer(primary_key=True) - count = columns.Integer(required=False) - text = columns.Text(required=False, index=True) - text_set = columns.Set(columns.Text, required=False) - text_list = columns.List(columns.Text, required=False) - text_map = columns.Map(columns.Text, columns.Text, required=False) +from tests.integration.cqlengine import is_prepend_reversed +from tests.integration.cqlengine.base import BaseCassEngTestCase, TestQueryUpdateModel +from tests.integration.cqlengine import execute_count +from tests.integration import greaterthancass20 class QueryUpdateTests(BaseCassEngTestCase): @@ -45,6 +39,7 @@ def tearDownClass(cls): super(QueryUpdateTests, cls).tearDownClass() drop_table(TestQueryUpdateModel) + @execute_count(8) def test_update_values(self): """ tests calling udpate on a queryset """ partition = uuid4() @@ -65,6 +60,7 @@ def test_update_values(self): self.assertEqual(row.count, 6 if i == 3 else i) self.assertEqual(row.text, str(i)) + @execute_count(6) def test_update_values_validation(self): """ tests calling udpate on models with values passed in """ partition = uuid4() @@ -91,6 +87,7 @@ def test_primary_key_update_failure(self): with self.assertRaises(ValidationError): TestQueryUpdateModel.objects(partition=uuid4(), cluster=3).update(cluster=5000) + @execute_count(8) def test_null_update_deletes_column(self): """ setting a field to null in the update should issue a delete statement """ partition = uuid4() @@ -111,6 +108,7 @@ def test_null_update_deletes_column(self): self.assertEqual(row.count, i) self.assertEqual(row.text, None if i == 3 else str(i)) + @execute_count(9) def test_mixed_value_and_null_update(self): """ tests that updating a columns value, and removing another works properly """ partition = uuid4() @@ -131,9 +129,7 @@ def test_mixed_value_and_null_update(self): self.assertEqual(row.count, 6 if i == 3 else i) self.assertEqual(row.text, None if i == 3 else str(i)) - def test_counter_updates(self): - pass - + @execute_count(3) def test_set_add_updates(self): partition = uuid4() cluster = 1 @@ -144,6 +140,7 @@ def test_set_add_updates(self): obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) 
self.assertEqual(obj.text_set, set(("foo", "bar"))) + @execute_count(2) def test_set_add_updates_new_record(self): """ If the key doesn't exist yet, an update creates the record """ @@ -154,6 +151,7 @@ def test_set_add_updates_new_record(self): obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) self.assertEqual(obj.text_set, set(("bar",))) + @execute_count(3) def test_set_remove_updates(self): partition = uuid4() cluster = 1 @@ -165,6 +163,7 @@ def test_set_remove_updates(self): obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) self.assertEqual(obj.text_set, set(("baz",))) + @execute_count(3) def test_set_remove_new_record(self): """ Removing something not in the set should silently do nothing """ @@ -178,6 +177,7 @@ def test_set_remove_new_record(self): obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) self.assertEqual(obj.text_set, set(("foo",))) + @execute_count(3) def test_list_append_updates(self): partition = uuid4() cluster = 1 @@ -189,6 +189,7 @@ def test_list_append_updates(self): obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) self.assertEqual(obj.text_list, ["foo", "bar"]) + @execute_count(3) def test_list_prepend_updates(self): """ Prepend two things since order is reversed by default by CQL """ partition = uuid4() @@ -204,6 +205,7 @@ def test_list_prepend_updates(self): expected = (prepended[::-1] if is_prepend_reversed() else prepended) + original self.assertEqual(obj.text_list, expected) + @execute_count(3) def test_map_update_updates(self): """ Merge a dictionary into existing value """ partition = uuid4() @@ -217,6 +219,7 @@ def test_map_update_updates(self): obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) self.assertEqual(obj.text_map, {"foo": '1', "bar": '3', "baz": '4'}) + @execute_count(3) def test_map_update_none_deletes_key(self): """ The CQL behavior is if you set a key in a map to null it deletes that key from the map. Test that this works with __update. @@ -231,3 +234,116 @@ def test_map_update_none_deletes_key(self): text_map__update={"bar": None}) obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) self.assertEqual(obj.text_map, {"foo": '1'}) + + @greaterthancass20 + @execute_count(5) + def test_map_update_remove(self): + """ + Test that map item removal with update(__remove=...) 
works + + @jira_ticket PYTHON-688 + """ + partition = uuid4() + cluster = 1 + TestQueryUpdateModel.objects.create( + partition=partition, + cluster=cluster, + text_map={"foo": '1', "bar": '2'} + ) + TestQueryUpdateModel.objects(partition=partition, cluster=cluster).update( + text_map__remove={"bar"}, + text_map__update={"foz": '4', "foo": '2'} + ) + obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) + self.assertEqual(obj.text_map, {"foo": '2', "foz": '4'}) + + TestQueryUpdateModel.objects(partition=partition, cluster=cluster).update( + text_map__remove={"foo", "foz"} + ) + self.assertEqual( + TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster).text_map, + {} + ) + + def test_map_remove_rejects_non_sets(self): + """ + Map item removal requires a set to match the CQL API + + @jira_ticket PYTHON-688 + """ + partition = uuid4() + cluster = 1 + TestQueryUpdateModel.objects.create( + partition=partition, + cluster=cluster, + text_map={"foo": '1', "bar": '2'} + ) + with self.assertRaises(ValidationError): + TestQueryUpdateModel.objects(partition=partition, cluster=cluster).update( + text_map__remove=["bar"] + ) + + @execute_count(3) + def test_an_extra_delete_is_not_sent(self): + """ + Test to ensure that an extra DELETE is not sent if an object is read + from the DB with a None value + + @since 3.9 + @jira_ticket PYTHON-719 + @expected_result only three queries are executed, the first one for + inserting the object, the second one for reading it, and the third + one for updating it + + @test_category object_mapper + """ + partition = uuid4() + cluster = 1 + + TestQueryUpdateModel.objects.create( + partition=partition, cluster=cluster) + + obj = TestQueryUpdateModel.objects( + partition=partition, cluster=cluster).first() + + self.assertFalse({k: v for (k, v) in obj._values.items() if v.deleted}) + + obj.text = 'foo' + obj.save() + #execute_count will check the execution count and + #assert no more calls than necessary where made + +class StaticDeleteModel(Model): + example_id = columns.Integer(partition_key=True, primary_key=True, default=uuid4) + example_static1 = columns.Integer(static=True) + example_static2 = columns.Integer(static=True) + example_clust = columns.Integer(primary_key=True) + + +class StaticDeleteTests(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(StaticDeleteTests, cls).setUpClass() + sync_table(StaticDeleteModel) + + @classmethod + def tearDownClass(cls): + super(StaticDeleteTests, cls).tearDownClass() + drop_table(StaticDeleteModel) + + def test_static_deletion(self): + """ + Test to ensure that cluster keys are not included when removing only static columns + + @since 3.6 + @jira_ticket PYTHON-608 + @expected_result Server should not throw an exception, and the static column should be deleted + + @test_category object_mapper + """ + StaticDeleteModel.create(example_id=5, example_clust=5, example_static2=1) + sdm = StaticDeleteModel.filter(example_id=5).first() + self.assertEqual(1, sdm.example_static2) + sdm.update(example_static2=None) + self.assertIsNone(sdm.example_static2) diff --git a/tests/integration/cqlengine/statements/__init__.py b/tests/integration/cqlengine/statements/__init__.py index 1c7af46e71..635f0d9e60 100644 --- a/tests/integration/cqlengine/statements/__init__.py +++ b/tests/integration/cqlengine/statements/__init__.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/tests/integration/cqlengine/statements/test_assignment_clauses.py b/tests/integration/cqlengine/statements/test_assignment_clauses.py index 4fc3a28374..c6d75a447e 100644 --- a/tests/integration/cqlengine/statements/test_assignment_clauses.py +++ b/tests/integration/cqlengine/statements/test_assignment_clauses.py @@ -1,20 +1,19 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest from cassandra.cqlengine.statements import AssignmentClause, SetUpdateClause, ListUpdateClause, MapUpdateClause, MapDeleteClause, FieldDeleteClause, CounterUpdateClause diff --git a/tests/integration/cqlengine/statements/test_assignment_statement.py b/tests/integration/cqlengine/statements/test_assignment_statement.py deleted file mode 100644 index 9d5481d98c..0000000000 --- a/tests/integration/cqlengine/statements/test_assignment_statement.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2013-2016 DataStax, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-try: - import unittest2 as unittest -except ImportError: - import unittest # noqa - -from cassandra.cqlengine.statements import AssignmentStatement, StatementException - - -class AssignmentStatementTest(unittest.TestCase): - - def test_add_assignment_type_checking(self): - """ tests that only assignment clauses can be added to queries """ - stmt = AssignmentStatement('table', []) - with self.assertRaises(StatementException): - stmt.add_assignment_clause('x=5') diff --git a/tests/integration/cqlengine/statements/test_base_clause.py b/tests/integration/cqlengine/statements/test_base_clause.py index 14d98782ea..cbba1ae36e 100644 --- a/tests/integration/cqlengine/statements/test_base_clause.py +++ b/tests/integration/cqlengine/statements/test_base_clause.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/tests/integration/cqlengine/statements/test_base_statement.py b/tests/integration/cqlengine/statements/test_base_statement.py index 1bda588eff..211d76cf5c 100644 --- a/tests/integration/cqlengine/statements/test_base_statement.py +++ b/tests/integration/cqlengine/statements/test_base_statement.py @@ -1,32 +1,38 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest + +from uuid import uuid4 from cassandra.query import FETCH_SIZE_UNSET -from cassandra.cqlengine.statements import BaseCQLStatement, StatementException +from cassandra.cqlengine.statements import BaseCQLStatement +from cassandra.cqlengine.management import sync_table, drop_table +from cassandra.cqlengine.statements import InsertStatement, UpdateStatement, SelectStatement, DeleteStatement, \ + WhereClause +from cassandra.cqlengine.operators import EqualsOperator, LikeOperator +from cassandra.cqlengine.columns import Column +from tests.integration.cqlengine.base import BaseCassEngTestCase, TestQueryUpdateModel +from tests.integration.cqlengine import DEFAULT_KEYSPACE +from tests.integration import greaterthanorequalcass3_10, lessthandse69, TestCluster + +from cassandra.cqlengine.connection import execute -class BaseStatementTest(unittest.TestCase): - def test_where_clause_type_checking(self): - """ tests that only assignment clauses can be added to queries """ - stmt = BaseCQLStatement('table', []) - with self.assertRaises(StatementException): - stmt.add_where_clause('x=5') +class BaseStatementTest(unittest.TestCase): def test_fetch_size(self): """ tests that fetch_size is correctly set """ @@ -38,3 +44,114 @@ def test_fetch_size(self): stmt = BaseCQLStatement('table', None) self.assertEqual(stmt.fetch_size, FETCH_SIZE_UNSET) + + +class ExecuteStatementTest(BaseCassEngTestCase): + text = "text_for_db" + + @classmethod + def setUpClass(cls): + super(ExecuteStatementTest, cls).setUpClass() + sync_table(TestQueryUpdateModel) + cls.table_name = '{0}.test_query_update_model'.format(DEFAULT_KEYSPACE) + + @classmethod + def tearDownClass(cls): + super(ExecuteStatementTest, cls).tearDownClass() + drop_table(TestQueryUpdateModel) + + def _verify_statement(self, original): + st = SelectStatement(self.table_name) + result = execute(st) + response = result[0] + + for assignment in original.assignments: + self.assertEqual(response[assignment.field], assignment.value) + self.assertEqual(len(response), 7) + + def test_insert_statement_execute(self): + """ + Test to verify the execution of BaseCQLStatements using connection.execute + + @since 3.10 + @jira_ticket PYTHON-505 + @expected_result inserts a row in C*, updates the rows and then deletes + all the rows using BaseCQLStatements + + @test_category data_types:object_mapper + """ + partition = uuid4() + cluster = 1 + self._insert_statement(partition, cluster) + + # Verifying update statement + where = [WhereClause('partition', EqualsOperator(), partition), + WhereClause('cluster', EqualsOperator(), cluster)] + + st = UpdateStatement(self.table_name, where=where) + st.add_assignment(Column(db_field='count'), 2) + st.add_assignment(Column(db_field='text'), "text_for_db_update") + st.add_assignment(Column(db_field='text_set'), set(("foo_update", "bar_update"))) + st.add_assignment(Column(db_field='text_list'), ["foo_update", "bar_update"]) + st.add_assignment(Column(db_field='text_map'), {"foo": '3', "bar": '4'}) + + execute(st) + self._verify_statement(st) + + # Verifying delete statement + execute(DeleteStatement(self.table_name, where=where)) + self.assertEqual(TestQueryUpdateModel.objects.count(), 0) + + @greaterthanorequalcass3_10 + @lessthandse69 + def test_like_operator(self): + """ + Test to verify the like operator works appropriately + + @since 3.13 + @jira_ticket PYTHON-512 + @expected_result the expected row is read using LIKE + + @test_category 
data_types:object_mapper + """ + cluster = TestCluster() + session = cluster.connect() + self.addCleanup(cluster.shutdown) + + session.execute("""CREATE CUSTOM INDEX text_index ON {} (text) + USING 'org.apache.cassandra.index.sasi.SASIIndex';""".format(self.table_name)) + self.addCleanup(session.execute, "DROP INDEX {}.text_index".format(DEFAULT_KEYSPACE)) + + partition = uuid4() + cluster = 1 + self._insert_statement(partition, cluster) + + ss = SelectStatement(self.table_name) + like_clause = "text_for_%" + ss.add_where(Column(db_field='text'), LikeOperator(), like_clause) + self.assertEqual(str(ss), + 'SELECT * FROM {} WHERE "text" LIKE %(0)s'.format(self.table_name)) + + result = execute(ss) + self.assertEqual(result[0]["text"], self.text) + + q = TestQueryUpdateModel.objects.filter(text__like=like_clause).allow_filtering() + self.assertEqual(q[0].text, self.text) + + q = TestQueryUpdateModel.objects.filter(text__like=like_clause) + self.assertEqual(q[0].text, self.text) + + def _insert_statement(self, partition, cluster): + # Verifying insert statement + st = InsertStatement(self.table_name) + st.add_assignment(Column(db_field='partition'), partition) + st.add_assignment(Column(db_field='cluster'), cluster) + + st.add_assignment(Column(db_field='count'), 1) + st.add_assignment(Column(db_field='text'), self.text) + st.add_assignment(Column(db_field='text_set'), set(("foo", "bar"))) + st.add_assignment(Column(db_field='text_list'), ["foo", "bar"]) + st.add_assignment(Column(db_field='text_map'), {"foo": '1', "bar": '2'}) + + execute(st) + self._verify_statement(st) diff --git a/tests/integration/cqlengine/statements/test_delete_statement.py b/tests/integration/cqlengine/statements/test_delete_statement.py index 6fb7c5317e..433fa759ac 100644 --- a/tests/integration/cqlengine/statements/test_delete_statement.py +++ b/tests/integration/cqlengine/statements/test_delete_statement.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +15,10 @@ # limitations under the License. 
from unittest import TestCase + +from cassandra.cqlengine.columns import Column from cassandra.cqlengine.statements import DeleteStatement, WhereClause, MapDeleteClause, ConditionalClause from cassandra.cqlengine.operators import * -import six class DeleteStatementTests(TestCase): @@ -29,57 +32,61 @@ def test_single_field_is_listified(self): def test_field_rendering(self): """ tests that fields are properly added to the select statement """ ds = DeleteStatement('table', ['f1', 'f2']) - self.assertTrue(six.text_type(ds).startswith('DELETE "f1", "f2"'), six.text_type(ds)) + self.assertTrue(str(ds).startswith('DELETE "f1", "f2"'), str(ds)) self.assertTrue(str(ds).startswith('DELETE "f1", "f2"'), str(ds)) def test_none_fields_rendering(self): """ tests that a '*' is added if no fields are passed in """ ds = DeleteStatement('table', None) - self.assertTrue(six.text_type(ds).startswith('DELETE FROM'), six.text_type(ds)) + self.assertTrue(str(ds).startswith('DELETE FROM'), str(ds)) self.assertTrue(str(ds).startswith('DELETE FROM'), str(ds)) def test_table_rendering(self): ds = DeleteStatement('table', None) - self.assertTrue(six.text_type(ds).startswith('DELETE FROM table'), six.text_type(ds)) + self.assertTrue(str(ds).startswith('DELETE FROM table'), str(ds)) self.assertTrue(str(ds).startswith('DELETE FROM table'), str(ds)) def test_where_clause_rendering(self): ds = DeleteStatement('table', None) - ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b')) - self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "a" = %(0)s', six.text_type(ds)) + ds.add_where(Column(db_field='a'), EqualsOperator(), 'b') + self.assertEqual(str(ds), 'DELETE FROM table WHERE "a" = %(0)s', str(ds)) def test_context_update(self): ds = DeleteStatement('table', None) ds.add_field(MapDeleteClause('d', {1: 2}, {1: 2, 3: 4})) - ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b')) + ds.add_where(Column(db_field='a'), EqualsOperator(), 'b') ds.update_context_id(7) - self.assertEqual(six.text_type(ds), 'DELETE "d"[%(8)s] FROM table WHERE "a" = %(7)s') + self.assertEqual(str(ds), 'DELETE "d"[%(8)s] FROM table WHERE "a" = %(7)s') self.assertEqual(ds.get_context(), {'7': 'b', '8': 3}) def test_context(self): ds = DeleteStatement('table', None) - ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b')) + ds.add_where(Column(db_field='a'), EqualsOperator(), 'b') self.assertEqual(ds.get_context(), {'0': 'b'}) def test_range_deletion_rendering(self): ds = DeleteStatement('table', None) - ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b')) - ds.add_where_clause(WhereClause('created_at', GreaterThanOrEqualOperator(), '0')) - ds.add_where_clause(WhereClause('created_at', LessThanOrEqualOperator(), '10')) - self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "a" = %(0)s AND "created_at" >= %(1)s AND "created_at" <= %(2)s', six.text_type(ds)) + ds.add_where(Column(db_field='a'), EqualsOperator(), 'b') + ds.add_where(Column(db_field='created_at'), GreaterThanOrEqualOperator(), '0') + ds.add_where(Column(db_field='created_at'), LessThanOrEqualOperator(), '10') + self.assertEqual(str(ds), 'DELETE FROM table WHERE "a" = %(0)s AND "created_at" >= %(1)s AND "created_at" <= %(2)s', str(ds)) + + ds = DeleteStatement('table', None) + ds.add_where(Column(db_field='a'), EqualsOperator(), 'b') + ds.add_where(Column(db_field='created_at'), InOperator(), ['0', '10', '20']) + self.assertEqual(str(ds), 'DELETE FROM table WHERE "a" = %(0)s AND "created_at" IN %(1)s', str(ds)) ds = DeleteStatement('table', None) - 
ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b')) - ds.add_where_clause(WhereClause('created_at', InOperator(), ['0', '10', '20'])) - self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "a" = %(0)s AND "created_at" IN %(1)s', six.text_type(ds)) + ds.add_where(Column(db_field='a'), NotEqualsOperator(), 'b') + self.assertEqual(str(ds), 'DELETE FROM table WHERE "a" != %(0)s', str(ds)) def test_delete_conditional(self): where = [WhereClause('id', EqualsOperator(), 1)] conditionals = [ConditionalClause('f0', 'value0'), ConditionalClause('f1', 'value1')] ds = DeleteStatement('table', where=where, conditionals=conditionals) self.assertEqual(len(ds.conditionals), len(conditionals)) - self.assertEqual(six.text_type(ds), 'DELETE FROM table WHERE "id" = %(0)s IF "f0" = %(1)s AND "f1" = %(2)s', six.text_type(ds)) + self.assertEqual(str(ds), 'DELETE FROM table WHERE "id" = %(0)s IF "f0" = %(1)s AND "f1" = %(2)s', str(ds)) fields = ['one', 'two'] ds = DeleteStatement('table', fields=fields, where=where, conditionals=conditionals) - self.assertEqual(six.text_type(ds), 'DELETE "one", "two" FROM table WHERE "id" = %(0)s IF "f0" = %(1)s AND "f1" = %(2)s', six.text_type(ds)) + self.assertEqual(str(ds), 'DELETE "one", "two" FROM table WHERE "id" = %(0)s IF "f0" = %(1)s AND "f1" = %(2)s', str(ds)) diff --git a/tests/integration/cqlengine/statements/test_insert_statement.py b/tests/integration/cqlengine/statements/test_insert_statement.py index dc6465e247..f3f6b4fd92 100644 --- a/tests/integration/cqlengine/statements/test_insert_statement.py +++ b/tests/integration/cqlengine/statements/test_insert_statement.py @@ -1,51 +1,44 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
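The DeleteStatement tests above exercise the newer add_where(Column(...), operator, value) signature plus LWT conditionals. A short cluster-free sketch of the same rendering, using the tests' placeholder table and column names:

from cassandra.cqlengine.columns import Column
from cassandra.cqlengine.operators import (EqualsOperator, GreaterThanOrEqualOperator,
                                            LessThanOrEqualOperator)
from cassandra.cqlengine.statements import ConditionalClause, DeleteStatement, WhereClause

# Range delete: each add_where() appends a numbered placeholder.
ds = DeleteStatement('table', None)
ds.add_where(Column(db_field='a'), EqualsOperator(), 'b')
ds.add_where(Column(db_field='created_at'), GreaterThanOrEqualOperator(), '0')
ds.add_where(Column(db_field='created_at'), LessThanOrEqualOperator(), '10')
print(str(ds))
# DELETE FROM table WHERE "a" = %(0)s AND "created_at" >= %(1)s AND "created_at" <= %(2)s
print(ds.get_context())  # {'0': 'b', '1': '0', '2': '10'}

# Conditional (LWT) delete, mirroring test_delete_conditional.
lwt = DeleteStatement('table',
                      where=[WhereClause('id', EqualsOperator(), 1)],
                      conditionals=[ConditionalClause('f0', 'value0'),
                                    ConditionalClause('f1', 'value1')])
print(str(lwt))
# DELETE FROM table WHERE "id" = %(0)s IF "f0" = %(1)s AND "f1" = %(2)s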
-try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest -from cassandra.cqlengine.statements import InsertStatement, StatementException, AssignmentClause +from cassandra.cqlengine.columns import Column +from cassandra.cqlengine.statements import InsertStatement -import six class InsertStatementTests(unittest.TestCase): - def test_where_clause_failure(self): - """ tests that where clauses cannot be added to Insert statements """ - ist = InsertStatement('table', None) - with self.assertRaises(StatementException): - ist.add_where_clause('s') - def test_statement(self): ist = InsertStatement('table', None) - ist.add_assignment_clause(AssignmentClause('a', 'b')) - ist.add_assignment_clause(AssignmentClause('c', 'd')) + ist.add_assignment(Column(db_field='a'), 'b') + ist.add_assignment(Column(db_field='c'), 'd') self.assertEqual( - six.text_type(ist), + str(ist), 'INSERT INTO table ("a", "c") VALUES (%(0)s, %(1)s)' ) def test_context_update(self): ist = InsertStatement('table', None) - ist.add_assignment_clause(AssignmentClause('a', 'b')) - ist.add_assignment_clause(AssignmentClause('c', 'd')) + ist.add_assignment(Column(db_field='a'), 'b') + ist.add_assignment(Column(db_field='c'), 'd') ist.update_context_id(4) self.assertEqual( - six.text_type(ist), + str(ist), 'INSERT INTO table ("a", "c") VALUES (%(4)s, %(5)s)' ) ctx = ist.get_context() @@ -53,6 +46,6 @@ def test_context_update(self): def test_additional_rendering(self): ist = InsertStatement('table', ttl=60) - ist.add_assignment_clause(AssignmentClause('a', 'b')) - ist.add_assignment_clause(AssignmentClause('c', 'd')) - self.assertIn('USING TTL 60', six.text_type(ist)) + ist.add_assignment(Column(db_field='a'), 'b') + ist.add_assignment(Column(db_field='c'), 'd') + self.assertIn('USING TTL 60', str(ist)) diff --git a/tests/integration/cqlengine/statements/test_select_statement.py b/tests/integration/cqlengine/statements/test_select_statement.py index 6612333809..9478202786 100644 --- a/tests/integration/cqlengine/statements/test_select_statement.py +++ b/tests/integration/cqlengine/statements/test_select_statement.py @@ -1,24 +1,23 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
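For the InsertStatement tests above, the rendering and context-id behaviour can be reproduced offline; shifting the context id is what lets several statements share one parameter map inside a batch. A minimal sketch with the same placeholder names:

from cassandra.cqlengine.columns import Column
from cassandra.cqlengine.statements import InsertStatement

ist = InsertStatement('table', None)
ist.add_assignment(Column(db_field='a'), 'b')
ist.add_assignment(Column(db_field='c'), 'd')
print(str(ist))  # INSERT INTO table ("a", "c") VALUES (%(0)s, %(1)s)

# Renumbering the context shifts every placeholder and context key together.
ist.update_context_id(4)
print(str(ist))           # INSERT INTO table ("a", "c") VALUES (%(4)s, %(5)s)
print(ist.get_context())  # {'4': 'b', '5': 'd'}

# A TTL set on the statement shows up as a USING clause in the rendered CQL.
ttl_ist = InsertStatement('table', ttl=60)
ttl_ist.add_assignment(Column(db_field='a'), 'b')
assert 'USING TTL 60' in str(ttl_ist)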
-try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest +from cassandra.cqlengine.columns import Column from cassandra.cqlengine.statements import SelectStatement, WhereClause from cassandra.cqlengine.operators import * -import six class SelectStatementTests(unittest.TestCase): @@ -30,52 +29,52 @@ def test_single_field_is_listified(self): def test_field_rendering(self): """ tests that fields are properly added to the select statement """ ss = SelectStatement('table', ['f1', 'f2']) - self.assertTrue(six.text_type(ss).startswith('SELECT "f1", "f2"'), six.text_type(ss)) + self.assertTrue(str(ss).startswith('SELECT "f1", "f2"'), str(ss)) self.assertTrue(str(ss).startswith('SELECT "f1", "f2"'), str(ss)) def test_none_fields_rendering(self): """ tests that a '*' is added if no fields are passed in """ ss = SelectStatement('table') - self.assertTrue(six.text_type(ss).startswith('SELECT *'), six.text_type(ss)) + self.assertTrue(str(ss).startswith('SELECT *'), str(ss)) self.assertTrue(str(ss).startswith('SELECT *'), str(ss)) def test_table_rendering(self): ss = SelectStatement('table') - self.assertTrue(six.text_type(ss).startswith('SELECT * FROM table'), six.text_type(ss)) + self.assertTrue(str(ss).startswith('SELECT * FROM table'), str(ss)) self.assertTrue(str(ss).startswith('SELECT * FROM table'), str(ss)) def test_where_clause_rendering(self): ss = SelectStatement('table') - ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b')) - self.assertEqual(six.text_type(ss), 'SELECT * FROM table WHERE "a" = %(0)s', six.text_type(ss)) + ss.add_where(Column(db_field='a'), EqualsOperator(), 'b') + self.assertEqual(str(ss), 'SELECT * FROM table WHERE "a" = %(0)s', str(ss)) def test_count(self): ss = SelectStatement('table', count=True, limit=10, order_by='d') - ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b')) - self.assertEqual(six.text_type(ss), 'SELECT COUNT(*) FROM table WHERE "a" = %(0)s LIMIT 10', six.text_type(ss)) - self.assertIn('LIMIT', six.text_type(ss)) - self.assertNotIn('ORDER', six.text_type(ss)) + ss.add_where(Column(db_field='a'), EqualsOperator(), 'b') + self.assertEqual(str(ss), 'SELECT COUNT(*) FROM table WHERE "a" = %(0)s LIMIT 10', str(ss)) + self.assertIn('LIMIT', str(ss)) + self.assertNotIn('ORDER', str(ss)) def test_distinct(self): ss = SelectStatement('table', distinct_fields=['field2']) - ss.add_where_clause(WhereClause('field1', EqualsOperator(), 'b')) - self.assertEqual(six.text_type(ss), 'SELECT DISTINCT "field2" FROM table WHERE "field1" = %(0)s', six.text_type(ss)) + ss.add_where(Column(db_field='field1'), EqualsOperator(), 'b') + self.assertEqual(str(ss), 'SELECT DISTINCT "field2" FROM table WHERE "field1" = %(0)s', str(ss)) ss = SelectStatement('table', distinct_fields=['field1', 'field2']) - self.assertEqual(six.text_type(ss), 'SELECT DISTINCT "field1", "field2" FROM table') + self.assertEqual(str(ss), 'SELECT DISTINCT "field1", "field2" FROM table') ss = SelectStatement('table', distinct_fields=['field1'], count=True) - self.assertEqual(six.text_type(ss), 'SELECT DISTINCT COUNT("field1") FROM table') + self.assertEqual(str(ss), 'SELECT DISTINCT COUNT("field1") FROM table') def test_context(self): ss = SelectStatement('table') - ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b')) + ss.add_where(Column(db_field='a'), EqualsOperator(), 'b') self.assertEqual(ss.get_context(), {'0': 'b'}) def test_context_id_update(self): """ tests that the right things happen the the context id """ ss = 
SelectStatement('table') - ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b')) + ss.add_where(Column(db_field='a'), EqualsOperator(), 'b') self.assertEqual(ss.get_context(), {'0': 'b'}) self.assertEqual(str(ss), 'SELECT * FROM table WHERE "a" = %(0)s') @@ -91,20 +90,20 @@ def test_additional_rendering(self): limit=15, allow_filtering=True ) - qstr = six.text_type(ss) + qstr = str(ss) self.assertIn('LIMIT 15', qstr) self.assertIn('ORDER BY x, y', qstr) self.assertIn('ALLOW FILTERING', qstr) def test_limit_rendering(self): ss = SelectStatement('table', None, limit=10) - qstr = six.text_type(ss) + qstr = str(ss) self.assertIn('LIMIT 10', qstr) ss = SelectStatement('table', None, limit=0) - qstr = six.text_type(ss) + qstr = str(ss) self.assertNotIn('LIMIT', qstr) ss = SelectStatement('table', None, limit=None) - qstr = six.text_type(ss) + qstr = str(ss) self.assertNotIn('LIMIT', qstr) diff --git a/tests/integration/cqlengine/statements/test_update_statement.py b/tests/integration/cqlengine/statements/test_update_statement.py index a8ff865e25..4c6966b10f 100644 --- a/tests/integration/cqlengine/statements/test_update_statement.py +++ b/tests/integration/cqlengine/statements/test_update_statement.py @@ -1,26 +1,25 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
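The SelectStatement tests above cover count, distinct, limit, ordering and ALLOW FILTERING, all of which are visible in the rendered string, so a cluster-free sketch is enough (placeholder names again; the order_by list form is inferred from the 'ORDER BY x, y' assertion):

from cassandra.cqlengine.columns import Column
from cassandra.cqlengine.operators import EqualsOperator
from cassandra.cqlengine.statements import SelectStatement

ss = SelectStatement('table', ['f1', 'f2'],
                     order_by=['x', 'y'], limit=15, allow_filtering=True)
ss.add_where(Column(db_field='a'), EqualsOperator(), 'b')
qstr = str(ss)
# qstr contains 'SELECT "f1", "f2"', 'ORDER BY x, y', 'LIMIT 15' and 'ALLOW FILTERING'

# count=True renders COUNT(*) and drops any ORDER BY; limit=0 or None omits LIMIT.
count_ss = SelectStatement('table', count=True, limit=10)
count_ss.add_where(Column(db_field='a'), EqualsOperator(), 'b')
print(str(count_ss))  # SELECT COUNT(*) FROM table WHERE "a" = %(0)s LIMIT 10

distinct_ss = SelectStatement('table', distinct_fields=['field1', 'field2'])
print(str(distinct_ss))  # SELECT DISTINCT "field1", "field2" FROM table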
-try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest +from cassandra.cqlengine.columns import Column, Set, List, Text from cassandra.cqlengine.operators import * from cassandra.cqlengine.statements import (UpdateStatement, WhereClause, AssignmentClause, SetUpdateClause, ListUpdateClause) -import six class UpdateStatementTests(unittest.TestCase): @@ -28,59 +27,62 @@ class UpdateStatementTests(unittest.TestCase): def test_table_rendering(self): """ tests that fields are properly added to the select statement """ us = UpdateStatement('table') - self.assertTrue(six.text_type(us).startswith('UPDATE table SET'), six.text_type(us)) + self.assertTrue(str(us).startswith('UPDATE table SET'), str(us)) self.assertTrue(str(us).startswith('UPDATE table SET'), str(us)) def test_rendering(self): us = UpdateStatement('table') - us.add_assignment_clause(AssignmentClause('a', 'b')) - us.add_assignment_clause(AssignmentClause('c', 'd')) - us.add_where_clause(WhereClause('a', EqualsOperator(), 'x')) - self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = %(0)s, "c" = %(1)s WHERE "a" = %(2)s', six.text_type(us)) + us.add_assignment(Column(db_field='a'), 'b') + us.add_assignment(Column(db_field='c'), 'd') + us.add_where(Column(db_field='a'), EqualsOperator(), 'x') + self.assertEqual(str(us), 'UPDATE table SET "a" = %(0)s, "c" = %(1)s WHERE "a" = %(2)s', str(us)) + + us.add_where(Column(db_field='a'), NotEqualsOperator(), 'y') + self.assertEqual(str(us), 'UPDATE table SET "a" = %(0)s, "c" = %(1)s WHERE "a" = %(2)s AND "a" != %(3)s', str(us)) def test_context(self): us = UpdateStatement('table') - us.add_assignment_clause(AssignmentClause('a', 'b')) - us.add_assignment_clause(AssignmentClause('c', 'd')) - us.add_where_clause(WhereClause('a', EqualsOperator(), 'x')) + us.add_assignment(Column(db_field='a'), 'b') + us.add_assignment(Column(db_field='c'), 'd') + us.add_where(Column(db_field='a'), EqualsOperator(), 'x') self.assertEqual(us.get_context(), {'0': 'b', '1': 'd', '2': 'x'}) def test_context_update(self): us = UpdateStatement('table') - us.add_assignment_clause(AssignmentClause('a', 'b')) - us.add_assignment_clause(AssignmentClause('c', 'd')) - us.add_where_clause(WhereClause('a', EqualsOperator(), 'x')) + us.add_assignment(Column(db_field='a'), 'b') + us.add_assignment(Column(db_field='c'), 'd') + us.add_where(Column(db_field='a'), EqualsOperator(), 'x') us.update_context_id(3) - self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = %(4)s, "c" = %(5)s WHERE "a" = %(3)s') + self.assertEqual(str(us), 'UPDATE table SET "a" = %(4)s, "c" = %(5)s WHERE "a" = %(3)s') self.assertEqual(us.get_context(), {'4': 'b', '5': 'd', '3': 'x'}) def test_additional_rendering(self): us = UpdateStatement('table', ttl=60) - us.add_assignment_clause(AssignmentClause('a', 'b')) - us.add_where_clause(WhereClause('a', EqualsOperator(), 'x')) - self.assertIn('USING TTL 60', six.text_type(us)) + us.add_assignment(Column(db_field='a'), 'b') + us.add_where(Column(db_field='a'), EqualsOperator(), 'x') + self.assertIn('USING TTL 60', str(us)) def test_update_set_add(self): us = UpdateStatement('table') - us.add_assignment_clause(SetUpdateClause('a', set((1,)), operation='add')) - self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" + %(0)s') + us.add_update(Set(Text, db_field='a'), set((1,)), 'add') + self.assertEqual(str(us), 'UPDATE table SET "a" = "a" + %(0)s') def test_update_empty_set_add_does_not_assign(self): us = UpdateStatement('table') - 
us.add_assignment_clause(SetUpdateClause('a', set(), operation='add')) - self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" + %(0)s') + us.add_update(Set(Text, db_field='a'), set(), 'add') + self.assertFalse(us.assignments) def test_update_empty_set_removal_does_not_assign(self): us = UpdateStatement('table') - us.add_assignment_clause(SetUpdateClause('a', set(), operation='remove')) - self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" - %(0)s') + us.add_update(Set(Text, db_field='a'), set(), 'remove') + self.assertFalse(us.assignments) def test_update_list_prepend_with_empty_list(self): us = UpdateStatement('table') - us.add_assignment_clause(ListUpdateClause('a', [], operation='prepend')) - self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = %(0)s + "a"') + us.add_update(List(Text, db_field='a'), [], 'prepend') + self.assertFalse(us.assignments) def test_update_list_append_with_empty_list(self): us = UpdateStatement('table') - us.add_assignment_clause(ListUpdateClause('a', [], operation='append')) - self.assertEqual(six.text_type(us), 'UPDATE table SET "a" = "a" + %(0)s') + us.add_update(List(Text, db_field='a'), [], 'append') + self.assertFalse(us.assignments) diff --git a/tests/integration/cqlengine/statements/test_where_clause.py b/tests/integration/cqlengine/statements/test_where_clause.py index e3d95d4fa4..76eab13c3e 100644 --- a/tests/integration/cqlengine/statements/test_where_clause.py +++ b/tests/integration/cqlengine/statements/test_where_clause.py @@ -1,22 +1,20 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest -import six from cassandra.cqlengine.operators import EqualsOperator from cassandra.cqlengine.statements import StatementException, WhereClause @@ -33,7 +31,7 @@ def test_where_clause_rendering(self): wc = WhereClause('a', EqualsOperator(), 'c') wc.set_context_id(5) - self.assertEqual('"a" = %(5)s', six.text_type(wc), six.text_type(wc)) + self.assertEqual('"a" = %(5)s', str(wc), str(wc)) self.assertEqual('"a" = %(5)s', str(wc), type(wc)) def test_equality_method(self): diff --git a/tests/integration/cqlengine/test_batch_query.py b/tests/integration/cqlengine/test_batch_query.py index 355a118235..26f312c50a 100644 --- a/tests/integration/cqlengine/test_batch_query.py +++ b/tests/integration/cqlengine/test_batch_query.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. 
+# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -13,7 +15,7 @@ # limitations under the License. import warnings -import sure +import pytest from cassandra.cqlengine import columns from cassandra.cqlengine.management import drop_table, sync_table @@ -21,7 +23,7 @@ from cassandra.cqlengine.query import BatchQuery from tests.integration.cqlengine.base import BaseCassEngTestCase -from mock import patch +from unittest.mock import patch class TestMultiKeyModel(Model): partition = columns.Integer(primary_key=True) @@ -217,13 +219,13 @@ def test_callbacks_work_multiple_times(self): def my_callback(*args, **kwargs): call_history.append(args) - with warnings.catch_warnings(record=True) as w: + with pytest.warns() as w: with BatchQuery() as batch: batch.add_callback(my_callback) batch.execute() batch.execute() self.assertEqual(len(w), 2) # package filter setup to warn always - self.assertRegexpMatches(str(w[0].message), r"^Batch.*multiple.*") + self.assertRegex(str(w[0].message), r"^Batch.*multiple.*") def test_disable_multiple_callback_warning(self): """ diff --git a/tests/integration/cqlengine/test_connections.py b/tests/integration/cqlengine/test_connections.py new file mode 100644 index 0000000000..e767ece617 --- /dev/null +++ b/tests/integration/cqlengine/test_connections.py @@ -0,0 +1,672 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
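The UpdateStatement tests above show the collection-aware add_update() path: adding or removing an empty set, or prepending/appending an empty list, produces no assignment at all, so no no-op CQL is emitted. A small offline sketch with the same 'table'/'a' placeholders:

from cassandra.cqlengine.columns import Column, List, Set, Text
from cassandra.cqlengine.operators import EqualsOperator
from cassandra.cqlengine.statements import UpdateStatement

us = UpdateStatement('table')
us.add_update(Set(Text, db_field='a'), {'red'}, 'add')
print(str(us))  # UPDATE table SET "a" = "a" + %(0)s

# Empty collection operations are dropped instead of being rendered.
noop = UpdateStatement('table')
noop.add_update(Set(Text, db_field='a'), set(), 'add')
noop.add_update(List(Text, db_field='a'), [], 'prepend')
assert not noop.assignments

# Plain assignments and WHERE clauses render as before.
plain = UpdateStatement('table')
plain.add_assignment(Column(db_field='a'), 'b')
plain.add_where(Column(db_field='a'), EqualsOperator(), 'x')
print(str(plain))  # UPDATE table SET "a" = %(0)s WHERE "a" = %(1)s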
+ +from cassandra import InvalidRequest +from cassandra.cluster import NoHostAvailable +from cassandra.cqlengine import columns, CQLEngineException +from cassandra.cqlengine import connection as conn +from cassandra.cqlengine.management import drop_keyspace, sync_table, drop_table, create_keyspace_simple +from cassandra.cqlengine.models import Model, QuerySetDescriptor +from cassandra.cqlengine.query import ContextQuery, BatchQuery, ModelQuerySet +from tests.integration.cqlengine import setup_connection, DEFAULT_KEYSPACE +from tests.integration.cqlengine.base import BaseCassEngTestCase +from tests.integration.cqlengine.query import test_queryset +from tests.integration import local, CASSANDRA_IP, TestCluster + + +class TestModel(Model): + + __keyspace__ = 'ks1' + + partition = columns.Integer(primary_key=True) + cluster = columns.Integer(primary_key=True) + count = columns.Integer() + text = columns.Text() + + +class AnotherTestModel(Model): + + __keyspace__ = 'ks1' + + partition = columns.Integer(primary_key=True) + cluster = columns.Integer(primary_key=True) + count = columns.Integer() + text = columns.Text() + +class ContextQueryConnectionTests(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(ContextQueryConnectionTests, cls).setUpClass() + create_keyspace_simple('ks1', 1) + + conn.unregister_connection('default') + conn.register_connection('fake_cluster', ['1.2.3.4'], lazy_connect=True, retry_connect=True, default=True) + conn.register_connection('cluster', [CASSANDRA_IP]) + + with ContextQuery(TestModel, connection='cluster') as tm: + sync_table(tm) + + @classmethod + def tearDownClass(cls): + super(ContextQueryConnectionTests, cls).tearDownClass() + + with ContextQuery(TestModel, connection='cluster') as tm: + drop_table(tm) + drop_keyspace('ks1', connections=['cluster']) + + + # reset the default connection + conn.unregister_connection('fake_cluster') + conn.unregister_connection('cluster') + setup_connection(DEFAULT_KEYSPACE) + + def setUp(self): + super(BaseCassEngTestCase, self).setUp() + + def test_context_connection_priority(self): + """ + Tests to ensure the proper connection priority is honored. + + Explicit connection should have the highest priority, + Followed by context query connection + Default connection should be honored last. 
+ + @since 3.7 + @jira_ticket PYTHON-613 + @expected_result priorities should be honored + + @test_category object_mapper + """ + # model keyspace write/read + + # Set the default connection on the Model + TestModel.__connection__ = 'cluster' + with ContextQuery(TestModel) as tm: + tm.objects.create(partition=1, cluster=1) + + # ContextQuery connection should have priority over default one + with ContextQuery(TestModel, connection='fake_cluster') as tm: + with self.assertRaises(NoHostAvailable): + tm.objects.create(partition=1, cluster=1) + + # Explicit connection should have priority over ContextQuery one + with ContextQuery(TestModel, connection='fake_cluster') as tm: + tm.objects.using(connection='cluster').create(partition=1, cluster=1) + + # Reset the default conn of the model + TestModel.__connection__ = None + + # No model connection and an invalid default connection + with ContextQuery(TestModel) as tm: + with self.assertRaises(NoHostAvailable): + tm.objects.create(partition=1, cluster=1) + + def test_context_connection_with_keyspace(self): + """ + Tests to ensure keyspace param is honored + + @since 3.7 + @jira_ticket PYTHON-613 + @expected_result Invalid request is thrown + + @test_category object_mapper + """ + + # ks2 doesn't exist + with ContextQuery(TestModel, connection='cluster', keyspace='ks2') as tm: + with self.assertRaises(InvalidRequest): + tm.objects.create(partition=1, cluster=1) + + +class ManagementConnectionTests(BaseCassEngTestCase): + + keyspaces = ['ks1', 'ks2'] + conns = ['cluster'] + + @classmethod + def setUpClass(cls): + super(ManagementConnectionTests, cls).setUpClass() + conn.unregister_connection('default') + conn.register_connection('fake_cluster', ['127.0.0.100'], lazy_connect=True, retry_connect=True, default=True) + conn.register_connection('cluster', [CASSANDRA_IP]) + + @classmethod + def tearDownClass(cls): + super(ManagementConnectionTests, cls).tearDownClass() + + # reset the default connection + conn.unregister_connection('fake_cluster') + conn.unregister_connection('cluster') + setup_connection(DEFAULT_KEYSPACE) + + def setUp(self): + super(BaseCassEngTestCase, self).setUp() + + def test_create_drop_keyspace(self): + """ + Tests drop and create keyspace with connections explicitly set + + @since 3.7 + @jira_ticket PYTHON-613 + @expected_result keyspaces should be created and dropped + + @test_category object_mapper + """ + + # No connection (default is fake) + with self.assertRaises(NoHostAvailable): + create_keyspace_simple(self.keyspaces[0], 1) + + # Explicit connections + for ks in self.keyspaces: + create_keyspace_simple(ks, 1, connections=self.conns) + + for ks in self.keyspaces: + drop_keyspace(ks, connections=self.conns) + + def test_create_drop_table(self): + """ + Tests drop and create Table with connections explicitly set + + @since 3.7 + @jira_ticket PYTHON-613 + @expected_result Tables should be created and dropped + + @test_category object_mapper + """ + for ks in self.keyspaces: + create_keyspace_simple(ks, 1, connections=self.conns) + + # No connection (default is fake) + with self.assertRaises(NoHostAvailable): + sync_table(TestModel) + + # Explicit connections + sync_table(TestModel, connections=self.conns) + + # Explicit drop + drop_table(TestModel, connections=self.conns) + + # Model connection + TestModel.__connection__ = 'cluster' + sync_table(TestModel) + TestModel.__connection__ = None + + # No connection (default is fake) + with self.assertRaises(NoHostAvailable): + drop_table(TestModel) + + # Model connection + 
TestModel.__connection__ = 'cluster' + drop_table(TestModel) + TestModel.__connection__ = None + + # Model connection + for ks in self.keyspaces: + drop_keyspace(ks, connections=self.conns) + + def test_connection_creation_from_session(self): + """ + Test to ensure that you can register a connection from a session + @since 3.8 + @jira_ticket PYTHON-649 + @expected_result queries should execute appropriately + + @test_category object_mapper + """ + cluster = TestCluster() + session = cluster.connect() + connection_name = 'from_session' + conn.register_connection(connection_name, session=session) + self.assertIsNotNone(conn.get_connection(connection_name).cluster.metadata.get_host(CASSANDRA_IP)) + self.addCleanup(conn.unregister_connection, connection_name) + cluster.shutdown() + + def test_connection_from_hosts(self): + """ + Test to ensure that you can register a connection from a list of hosts + @since 3.8 + @jira_ticket PYTHON-692 + @expected_result queries should execute appropriately + + @test_category object_mapper + """ + connection_name = 'from_hosts' + conn.register_connection(connection_name, hosts=[CASSANDRA_IP]) + self.assertIsNotNone(conn.get_connection(connection_name).cluster.metadata.get_host(CASSANDRA_IP)) + self.addCleanup(conn.unregister_connection, connection_name) + + def test_connection_param_validation(self): + """ + Test to validate that invalid parameter combinations for registering connections via session are not tolerated + @since 3.8 + @jira_ticket PYTHON-649 + @expected_result queries should execute appropriately + + @test_category object_mapper + """ + cluster = TestCluster() + session = cluster.connect() + with self.assertRaises(CQLEngineException): + conn.register_connection("bad_coonection1", session=session, consistency="not_null") + with self.assertRaises(CQLEngineException): + conn.register_connection("bad_coonection2", session=session, lazy_connect="not_null") + with self.assertRaises(CQLEngineException): + conn.register_connection("bad_coonection3", session=session, retry_connect="not_null") + with self.assertRaises(CQLEngineException): + conn.register_connection("bad_coonection4", session=session, cluster_options="not_null") + with self.assertRaises(CQLEngineException): + conn.register_connection("bad_coonection5", hosts="not_null", session=session) + cluster.shutdown() + + cluster.shutdown() + + + cluster.shutdown() + +class BatchQueryConnectionTests(BaseCassEngTestCase): + + conns = ['cluster'] + + @classmethod + def setUpClass(cls): + super(BatchQueryConnectionTests, cls).setUpClass() + + create_keyspace_simple('ks1', 1) + sync_table(TestModel) + sync_table(AnotherTestModel) + + conn.unregister_connection('default') + conn.register_connection('fake_cluster', ['127.0.0.100'], lazy_connect=True, retry_connect=True, default=True) + conn.register_connection('cluster', [CASSANDRA_IP]) + + @classmethod + def tearDownClass(cls): + super(BatchQueryConnectionTests, cls).tearDownClass() + + # reset the default connection + conn.unregister_connection('fake_cluster') + conn.unregister_connection('cluster') + setup_connection(DEFAULT_KEYSPACE) + + drop_keyspace('ks1') + + def setUp(self): + super(BaseCassEngTestCase, self).setUp() + + def test_basic_batch_query(self): + """ + Test Batch queries with connections explicitly set + + @since 3.7 + @jira_ticket PYTHON-613 + @expected_result queries should execute appropriately + + @test_category object_mapper + """ + + # No connection with a QuerySet (default is a fake one) + with self.assertRaises(NoHostAvailable): + 
with BatchQuery() as b: + TestModel.objects.batch(b).create(partition=1, cluster=1) + + # Explicit connection with a QuerySet + with BatchQuery(connection='cluster') as b: + TestModel.objects.batch(b).create(partition=1, cluster=1) + + # Get an object from the BD + with ContextQuery(TestModel, connection='cluster') as tm: + obj = tm.objects.get(partition=1, cluster=1) + obj.__connection__ = None + + # No connection with a model (default is a fake one) + with self.assertRaises(NoHostAvailable): + with BatchQuery() as b: + obj.count = 2 + obj.batch(b).save() + + # Explicit connection with a model + with BatchQuery(connection='cluster') as b: + obj.count = 2 + obj.batch(b).save() + + def test_batch_query_different_connection(self): + """ + Test BatchQuery with Models that have a different connection + + @since 3.7 + @jira_ticket PYTHON-613 + @expected_result queries should execute appropriately + + @test_category object_mapper + """ + + # Testing on a model class + TestModel.__connection__ = 'cluster' + AnotherTestModel.__connection__ = 'cluster2' + + with self.assertRaises(CQLEngineException): + with BatchQuery() as b: + TestModel.objects.batch(b).create(partition=1, cluster=1) + AnotherTestModel.objects.batch(b).create(partition=1, cluster=1) + + TestModel.__connection__ = None + AnotherTestModel.__connection__ = None + + with BatchQuery(connection='cluster') as b: + TestModel.objects.batch(b).create(partition=1, cluster=1) + AnotherTestModel.objects.batch(b).create(partition=1, cluster=1) + + # Testing on a model instance + with ContextQuery(TestModel, AnotherTestModel, connection='cluster') as (tm, atm): + obj1 = tm.objects.get(partition=1, cluster=1) + obj2 = atm.objects.get(partition=1, cluster=1) + + obj1.__connection__ = 'cluster' + obj2.__connection__ = 'cluster2' + + obj1.count = 4 + obj2.count = 4 + + with self.assertRaises(CQLEngineException): + with BatchQuery() as b: + obj1.batch(b).save() + obj2.batch(b).save() + + def test_batch_query_connection_override(self): + """ + Test that we cannot override a BatchQuery connection per model + + @since 3.7 + @jira_ticket PYTHON-613 + @expected_result Proper exceptions should be raised + + @test_category object_mapper + """ + + with self.assertRaises(CQLEngineException): + with BatchQuery(connection='cluster') as b: + TestModel.batch(b).using(connection='test').save() + + with self.assertRaises(CQLEngineException): + with BatchQuery(connection='cluster') as b: + TestModel.using(connection='test').batch(b).save() + + with ContextQuery(TestModel, AnotherTestModel, connection='cluster') as (tm, atm): + obj1 = tm.objects.get(partition=1, cluster=1) + obj1.__connection__ = None + + with self.assertRaises(CQLEngineException): + with BatchQuery(connection='cluster') as b: + obj1.using(connection='test').batch(b).save() + + with self.assertRaises(CQLEngineException): + with BatchQuery(connection='cluster') as b: + obj1.batch(b).using(connection='test').save() + +class UsingDescriptorTests(BaseCassEngTestCase): + + conns = ['cluster'] + keyspaces = ['ks1', 'ks2'] + + @classmethod + def setUpClass(cls): + super(UsingDescriptorTests, cls).setUpClass() + + conn.unregister_connection('default') + conn.register_connection('fake_cluster', ['127.0.0.100'], lazy_connect=True, retry_connect=True, default=True) + conn.register_connection('cluster', [CASSANDRA_IP]) + + @classmethod + def tearDownClass(cls): + super(UsingDescriptorTests, cls).tearDownClass() + + # reset the default connection + conn.unregister_connection('fake_cluster') + 
conn.unregister_connection('cluster') + setup_connection(DEFAULT_KEYSPACE) + + for ks in cls.keyspaces: + drop_keyspace(ks) + + def setUp(self): + super(BaseCassEngTestCase, self).setUp() + + def _reset_data(self): + + for ks in self.keyspaces: + drop_keyspace(ks, connections=self.conns) + for ks in self.keyspaces: + create_keyspace_simple(ks, 1, connections=self.conns) + sync_table(TestModel, keyspaces=self.keyspaces, connections=self.conns) + + def test_keyspace(self): + """ + Test keyspace segregation when same connection is used + + @since 3.7 + @jira_ticket PYTHON-613 + @expected_result Keyspace segregation is honored + + @test_category object_mapper + """ + self._reset_data() + + with ContextQuery(TestModel, connection='cluster') as tm: + + # keyspace Model class + tm.objects.using(keyspace='ks2').create(partition=1, cluster=1) + tm.objects.using(keyspace='ks2').create(partition=2, cluster=2) + + with self.assertRaises(TestModel.DoesNotExist): + tm.objects.get(partition=1, cluster=1) # default keyspace ks1 + obj1 = tm.objects.using(keyspace='ks2').get(partition=1, cluster=1) + + obj1.count = 2 + obj1.save() + + with self.assertRaises(NoHostAvailable): + TestModel.objects.using(keyspace='ks2').get(partition=1, cluster=1) + + obj2 = TestModel.objects.using(connection='cluster', keyspace='ks2').get(partition=1, cluster=1) + self.assertEqual(obj2.count, 2) + + # Update test + TestModel.objects(partition=2, cluster=2).using(connection='cluster', keyspace='ks2').update(count=5) + obj3 = TestModel.objects.using(connection='cluster', keyspace='ks2').get(partition=2, cluster=2) + self.assertEqual(obj3.count, 5) + + TestModel.objects(partition=2, cluster=2).using(connection='cluster', keyspace='ks2').delete() + with self.assertRaises(TestModel.DoesNotExist): + TestModel.objects.using(connection='cluster', keyspace='ks2').get(partition=2, cluster=2) + + def test_connection(self): + """ + Test basic connection functionality + + @since 3.7 + @jira_ticket PYTHON-613 + @expected_result proper connection should be used + + @test_category object_mapper + """ + self._reset_data() + + # Model class + with self.assertRaises(NoHostAvailable): + TestModel.objects.create(partition=1, cluster=1) + + TestModel.objects.using(connection='cluster').create(partition=1, cluster=1) + TestModel.objects(partition=1, cluster=1).using(connection='cluster').update(count=2) + obj1 = TestModel.objects.using(connection='cluster').get(partition=1, cluster=1) + self.assertEqual(obj1.count, 2) + + obj1.using(connection='cluster').update(count=5) + obj1 = TestModel.objects.using(connection='cluster').get(partition=1, cluster=1) + self.assertEqual(obj1.count, 5) + + obj1.using(connection='cluster').delete() + with self.assertRaises(TestModel.DoesNotExist): + TestModel.objects.using(connection='cluster').get(partition=1, cluster=1) + + +class ModelQuerySetNew(ModelQuerySet): + def __init__(self, *args, **kwargs): + super(ModelQuerySetNew, self).__init__(*args, **kwargs) + self._connection = "cluster" + +class BaseConnectionTestNoDefault(object): + conns = ['cluster'] + + @classmethod + def setUpClass(cls): + conn.register_connection('cluster', [CASSANDRA_IP]) + test_queryset.TestModel.__queryset__ = ModelQuerySetNew + test_queryset.IndexedTestModel.__queryset__ = ModelQuerySetNew + test_queryset.CustomIndexedTestModel.__queryset__ = ModelQuerySetNew + test_queryset.IndexedCollectionsTestModel.__queryset__ = ModelQuerySetNew + test_queryset.TestMultiClusteringModel.__queryset__ = ModelQuerySetNew + + 
super(BaseConnectionTestNoDefault, cls).setUpClass() + conn.unregister_connection('default') + + @classmethod + def tearDownClass(cls): + conn.unregister_connection('cluster') + setup_connection(DEFAULT_KEYSPACE) + super(BaseConnectionTestNoDefault, cls).tearDownClass() + # reset the default connection + + def setUp(self): + super(BaseCassEngTestCase, self).setUp() + + +class TestQuerySetOperationConnection(BaseConnectionTestNoDefault, test_queryset.TestQuerySetOperation): + """ + Execute test_queryset.TestQuerySetOperation using non default connection + + @since 3.7 + @jira_ticket PYTHON-613 + @expected_result proper connection should be used + + @test_category object_mapper + """ + pass + + +class TestQuerySetDistinctNoDefault(BaseConnectionTestNoDefault, test_queryset.TestQuerySetDistinct): + """ + Execute test_queryset.TestQuerySetDistinct using non default connection + + @since 3.7 + @jira_ticket PYTHON-613 + @expected_result proper connection should be used + + @test_category object_mapper + """ + pass + + +class TestQuerySetOrderingNoDefault(BaseConnectionTestNoDefault, test_queryset.TestQuerySetOrdering): + """ + Execute test_queryset.TestQuerySetOrdering using non default connection + + @since 3.7 + @jira_ticket PYTHON-613 + @expected_result proper connection should be used + + @test_category object_mapper + """ + pass + + +class TestQuerySetCountSelectionAndIterationNoDefault(BaseConnectionTestNoDefault, test_queryset.TestQuerySetCountSelectionAndIteration): + """ + Execute test_queryset.TestQuerySetOrdering using non default connection + + @since 3.7 + @jira_ticket PYTHON-613 + @expected_result proper connection should be used + + @test_category object_mapper + """ + pass + + +class TestQuerySetSlicingNoDefault(BaseConnectionTestNoDefault, test_queryset.TestQuerySetSlicing): + """ + Execute test_queryset.TestQuerySetOrdering using non default connection + + @since 3.7 + @jira_ticket PYTHON-613 + @expected_result proper connection should be used + + @test_category object_mapper + """ + pass + + +class TestQuerySetValidationNoDefault(BaseConnectionTestNoDefault, test_queryset.TestQuerySetValidation): + """ + Execute test_queryset.TestQuerySetOrdering using non default connection + + @since 3.7 + @jira_ticket PYTHON-613 + @expected_result proper connection should be used + + @test_category object_mapper + """ + pass + + +class TestQuerySetDeleteNoDefault(BaseConnectionTestNoDefault, test_queryset.TestQuerySetDelete): + """ + Execute test_queryset.TestQuerySetDelete using non default connection + + @since 3.7 + @jira_ticket PYTHON-613 + @expected_result proper connection should be used + + @test_category object_mapper + """ + pass + + +class TestValuesListNoDefault(BaseConnectionTestNoDefault, test_queryset.TestValuesList): + """ + Execute test_queryset.TestValuesList using non default connection + + @since 3.7 + @jira_ticket PYTHON-613 + @expected_result proper connection should be used + + @test_category object_mapper + """ + pass + + +class TestObjectsPropertyNoDefault(BaseConnectionTestNoDefault, test_queryset.TestObjectsProperty): + """ + Execute test_queryset.TestObjectsProperty using non default connection + + @since 3.7 + @jira_ticket PYTHON-613 + @expected_result proper connection should be used + + @test_category object_mapper + """ + pass diff --git a/tests/integration/cqlengine/test_consistency.py b/tests/integration/cqlengine/test_consistency.py index 61cf3f0d50..3a6485eaed 100644 --- a/tests/integration/cqlengine/test_consistency.py +++ 
b/tests/integration/cqlengine/test_consistency.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import mock +from unittest import mock from uuid import uuid4 from cassandra import ConsistencyLevel as CL, ConsistencyLevel @@ -68,7 +70,6 @@ def test_update_uses_consistency(self): args = m.call_args self.assertEqual(CL.ALL, args[0][0].consistency_level) - def test_batch_consistency(self): with mock.patch.object(self.session, 'execute') as m: @@ -114,7 +115,7 @@ def test_delete(self): def test_default_consistency(self): # verify global assumed default - self.assertEqual(Session.default_consistency_level, ConsistencyLevel.LOCAL_ONE) + self.assertEqual(Session._default_consistency_level, ConsistencyLevel.LOCAL_ONE) # verify that this session default is set according to connection.setup # assumes tests/cqlengine/__init__ setup uses CL.ONE diff --git a/tests/integration/cqlengine/test_context_query.py b/tests/integration/cqlengine/test_context_query.py new file mode 100644 index 0000000000..bb226f58ce --- /dev/null +++ b/tests/integration/cqlengine/test_context_query.py @@ -0,0 +1,177 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
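The connection tests above (register_connection from a host list or an existing Session, plus the priority rules checked in test_context_connection_priority) translate into roughly the following usage. This is a sketch, not part of the test suite: it assumes a reachable node at 127.0.0.1, and the Comment model, keyspace and connection names are made up for illustration.

from cassandra.cluster import Cluster
from cassandra.cqlengine import columns
from cassandra.cqlengine import connection as conn
from cassandra.cqlengine.management import create_keyspace_simple, sync_table
from cassandra.cqlengine.models import Model
from cassandra.cqlengine.query import ContextQuery


class Comment(Model):                      # hypothetical model, for illustration only
    __keyspace__ = 'demo_ks'
    id = columns.Integer(primary_key=True)
    body = columns.Text()


# A named connection can come from a host list...
conn.register_connection('cluster1', hosts=['127.0.0.1'], default=True)
# ...or wrap a Session you already manage yourself.
session = Cluster(['127.0.0.1']).connect()
conn.register_connection('from_session', session=session)

create_keyspace_simple('demo_ks', 1, connections=['cluster1'])
sync_table(Comment, connections=['cluster1'])

# Priority, highest first (per the test above): explicit .using(connection=...),
# then the ContextQuery connection, then the model's __connection__, then the default.
with ContextQuery(Comment, connection='from_session') as c:
    c.objects.using(connection='cluster1').create(id=1, body='hello')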
+ +from cassandra.cqlengine import columns +from cassandra.cqlengine.management import drop_keyspace, sync_table, create_keyspace_simple +from cassandra.cqlengine.models import Model +from cassandra.cqlengine.query import ContextQuery +from tests.integration.cqlengine.base import BaseCassEngTestCase + + +class TestModel(Model): + + __keyspace__ = 'ks1' + + partition = columns.Integer(primary_key=True) + cluster = columns.Integer(primary_key=True) + count = columns.Integer() + text = columns.Text() + + +class ContextQueryTests(BaseCassEngTestCase): + + KEYSPACES = ('ks1', 'ks2', 'ks3', 'ks4') + + @classmethod + def setUpClass(cls): + super(ContextQueryTests, cls).setUpClass() + for ks in cls.KEYSPACES: + create_keyspace_simple(ks, 1) + sync_table(TestModel, keyspaces=cls.KEYSPACES) + + @classmethod + def tearDownClass(cls): + super(ContextQueryTests, cls).tearDownClass() + for ks in cls.KEYSPACES: + drop_keyspace(ks) + + + def setUp(self): + super(ContextQueryTests, self).setUp() + for ks in self.KEYSPACES: + with ContextQuery(TestModel, keyspace=ks) as tm: + for obj in tm.all(): + obj.delete() + + def test_context_manager(self): + """ + Validates that when a context query is constructed that the + keyspace of the returned model is toggled appropriately + + @since 3.6 + @jira_ticket PYTHON-598 + @expected_result default keyspace should be used + + @test_category query + """ + # model keyspace write/read + for ks in self.KEYSPACES: + with ContextQuery(TestModel, keyspace=ks) as tm: + self.assertEqual(tm.__keyspace__, ks) + + self.assertEqual(TestModel._get_keyspace(), 'ks1') + + def test_default_keyspace(self): + """ + Tests the use of context queries with the default model keyspsace + + @since 3.6 + @jira_ticket PYTHON-598 + @expected_result default keyspace should be used + + @test_category query + """ + # model keyspace write/read + for i in range(5): + TestModel.objects.create(partition=i, cluster=i) + + with ContextQuery(TestModel) as tm: + self.assertEqual(5, len(tm.objects.all())) + + with ContextQuery(TestModel, keyspace='ks1') as tm: + self.assertEqual(5, len(tm.objects.all())) + + for ks in self.KEYSPACES[1:]: + with ContextQuery(TestModel, keyspace=ks) as tm: + self.assertEqual(0, len(tm.objects.all())) + + def test_context_keyspace(self): + """ + Tests the use of context queries with non default keyspaces + + @since 3.6 + @jira_ticket PYTHON-598 + @expected_result queries should be routed to appropriate keyspaces + + @test_category query + """ + for i in range(5): + with ContextQuery(TestModel, keyspace='ks4') as tm: + tm.objects.create(partition=i, cluster=i) + + with ContextQuery(TestModel, keyspace='ks4') as tm: + self.assertEqual(5, len(tm.objects.all())) + + self.assertEqual(0, len(TestModel.objects.all())) + + for ks in self.KEYSPACES[:2]: + with ContextQuery(TestModel, keyspace=ks) as tm: + self.assertEqual(0, len(tm.objects.all())) + + # simple data update + with ContextQuery(TestModel, keyspace='ks4') as tm: + obj = tm.objects.get(partition=1) + obj.update(count=42) + + self.assertEqual(42, tm.objects.get(partition=1).count) + + def test_context_multiple_models(self): + """ + Tests the use of multiple models with the context manager + + @since 3.7 + @jira_ticket PYTHON-613 + @expected_result all models are properly updated with the context + + @test_category query + """ + + with ContextQuery(TestModel, TestModel, keyspace='ks4') as (tm1, tm2): + + self.assertNotEqual(tm1, tm2) + self.assertEqual(tm1.__keyspace__, 'ks4') + self.assertEqual(tm2.__keyspace__, 'ks4') + + def 
test_context_invalid_parameters(self): + """ + Tests that invalid parameters are raised by the context manager + + @since 3.7 + @jira_ticket PYTHON-613 + @expected_result a ValueError is raised when passing invalid parameters + + @test_category query + """ + + with self.assertRaises(ValueError): + with ContextQuery(keyspace='ks2'): + pass + + with self.assertRaises(ValueError): + with ContextQuery(42) as tm: + pass + + with self.assertRaises(ValueError): + with ContextQuery(TestModel, 42): + pass + + with self.assertRaises(ValueError): + with ContextQuery(TestModel, unknown_param=42): + pass + + with self.assertRaises(ValueError): + with ContextQuery(TestModel, keyspace='ks2', unknown_param=42): + pass \ No newline at end of file diff --git a/tests/integration/cqlengine/test_ifexists.py b/tests/integration/cqlengine/test_ifexists.py index 6d693f7b45..32f48b58ff 100644 --- a/tests/integration/cqlengine/test_ifexists.py +++ b/tests/integration/cqlengine/test_ifexists.py @@ -1,22 +1,20 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
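The ContextQuery tests above are about keyspace routing: the model's __keyspace__ is only a default, and the context temporarily swaps it. A short sketch of that behaviour outside the test harness; the Event model, keyspace names and contact point are illustrative assumptions.

from cassandra.cqlengine import columns, connection
from cassandra.cqlengine.management import create_keyspace_simple, sync_table
from cassandra.cqlengine.models import Model
from cassandra.cqlengine.query import ContextQuery


class Event(Model):                  # hypothetical model, for illustration only
    __keyspace__ = 'ks1'             # class-level default keyspace
    id = columns.Integer(primary_key=True)
    name = columns.Text()


connection.setup(['127.0.0.1'], 'ks1')
for ks in ('ks1', 'ks2'):
    create_keyspace_simple(ks, 1)
sync_table(Event, keyspaces=['ks1', 'ks2'])   # same table in both keyspaces

# Rows written through the context land only in the keyspace it names.
with ContextQuery(Event, keyspace='ks2') as e:
    e.objects.create(id=1, name='only in ks2')
    assert len(e.objects.all()) == 1

assert len(Event.objects.all()) == 0          # ks1, the class default, stays empty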
-try: - import unittest2 as unittest -except ImportError: - import unittest # noqa - -import mock +import unittest +from unittest import mock from uuid import uuid4 from cassandra.cqlengine import columns @@ -104,7 +102,7 @@ def test_update_if_exists(self): m = TestIfExistsModel.get(id=id) self.assertEqual(m.text, 'changed_again') - m = TestIfExistsModel(id=uuid4(), count=44) # do not exists + m = TestIfExistsModel(id=uuid4(), count=44) # do not exist with self.assertRaises(LWTException) as assertion: m.if_exists().update() @@ -159,7 +157,7 @@ def test_batch_update_if_exists_success(self): @unittest.skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0") def test_batch_mixed_update_if_exists_success(self): """ - Tests that batch update with with one bad query will still fail with LWTException + Tests that batch update with one bad query will still fail with LWTException @since 3.1 @jira_ticket PYTHON-432 @@ -181,7 +179,7 @@ def test_batch_mixed_update_if_exists_success(self): @unittest.skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0") def test_delete_if_exists(self): """ - Tests that delete with if_exists work, and throw proper LWT exception when they are are not applied + Tests that delete with if_exists work, and throws proper LWT exception when they are not applied @since 3.1 @jira_ticket PYTHON-432 @@ -197,7 +195,7 @@ def test_delete_if_exists(self): q = TestIfExistsModel.objects(id=id) self.assertEqual(len(q), 0) - m = TestIfExistsModel(id=uuid4(), count=44) # do not exists + m = TestIfExistsModel(id=uuid4(), count=44) # do not exist with self.assertRaises(LWTException) as assertion: m.if_exists().delete() @@ -216,7 +214,7 @@ def test_delete_if_exists(self): @unittest.skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0") def test_batch_delete_if_exists_success(self): """ - Tests that batch deletes with if_exists work, and throw proper LWTException when they are are not applied + Tests that batch deletes with if_exists work, and throws proper LWTException when they are not applied @since 3.1 @jira_ticket PYTHON-432 @@ -247,7 +245,7 @@ def test_batch_delete_if_exists_success(self): @unittest.skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0") def test_batch_delete_mixed(self): """ - Tests that batch deletes with multiple queries and throw proper LWTException when they are are not all applicable + Tests that batch deletes with multiple queries and throws proper LWTException when they are not all applicable @since 3.1 @jira_ticket PYTHON-432 @@ -313,4 +311,3 @@ def test_instance_raise_exception(self): id = uuid4() with self.assertRaises(IfExistsWithCounterColumn): TestIfExistsWithCounterModel.if_exists() - diff --git a/tests/integration/cqlengine/test_ifnotexists.py b/tests/integration/cqlengine/test_ifnotexists.py index 8c5e89c185..793ca80355 100644 --- a/tests/integration/cqlengine/test_ifnotexists.py +++ b/tests/integration/cqlengine/test_ifnotexists.py @@ -1,22 +1,20 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa - -import mock +import unittest +from unittest import mock from uuid import uuid4 from cassandra.cqlengine import columns @@ -202,4 +200,3 @@ def test_instance_raise_exception(self): id = uuid4() with self.assertRaises(IfNotExistsWithCounterColumn): TestIfNotExistsWithCounterModel.if_not_exists() - diff --git a/tests/integration/cqlengine/test_lwt_conditional.py b/tests/integration/cqlengine/test_lwt_conditional.py index 2d4e3181c7..91edce44c1 100644 --- a/tests/integration/cqlengine/test_lwt_conditional.py +++ b/tests/integration/cqlengine/test_lwt_conditional.py @@ -1,23 +1,20 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
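The if_exists/if_not_exists tests above (and the conditional tests that follow) map onto model-level LWT guards: a failed condition surfaces as LWTException, whose existing attribute carries the row that blocked the write, including the [applied] flag. A sketch of the typical flow, assuming a reachable node; the Account model and keyspace are made up for illustration.

from uuid import uuid4

from cassandra.cqlengine import columns, connection
from cassandra.cqlengine.management import create_keyspace_simple, sync_table
from cassandra.cqlengine.models import Model
from cassandra.cqlengine.query import LWTException


class Account(Model):                       # hypothetical model, for illustration only
    __keyspace__ = 'demo_ks'
    id = columns.UUID(primary_key=True, default=uuid4)
    balance = columns.Integer()


connection.setup(['127.0.0.1'], 'demo_ks')
create_keyspace_simple('demo_ks', 1)
sync_table(Account)

acct = Account.if_not_exists().create(balance=10)           # INSERT ... IF NOT EXISTS

try:
    Account.if_not_exists().create(id=acct.id, balance=99)  # same PK: not applied
except LWTException as exc:
    print(exc.existing)   # the row that blocked the insert, with '[applied]': False

# IF EXISTS and IF <column> = <value> guards on updates and deletes.
acct.if_exists().update(balance=20)
Account.objects(id=acct.id).iff(balance=20).delete()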
-try: - import unittest2 as unittest -except ImportError: - import unittest # noqa - -import mock -import six +import unittest +from unittest import mock from uuid import uuid4 from cassandra.cqlengine import columns @@ -27,7 +24,7 @@ from cassandra.cqlengine.statements import ConditionalClause from tests.integration.cqlengine.base import BaseCassEngTestCase -from tests.integration import CASSANDRA_VERSION +from tests.integration import greaterthancass20 class TestConditionalModel(Model): @@ -36,7 +33,14 @@ class TestConditionalModel(Model): text = columns.Text(required=False) -@unittest.skipUnless(CASSANDRA_VERSION >= '2.0.0', "conditionals only supported on cassandra 2.0 or higher") +class TestUpdateModel(Model): + partition = columns.Integer(primary_key=True) + cluster = columns.Integer(primary_key=True) + value = columns.Integer(required=False) + text = columns.Text(required=False, index=True) + + +@greaterthancass20 class TestConditional(BaseCassEngTestCase): @classmethod @@ -50,7 +54,7 @@ def tearDownClass(cls): drop_table(TestConditionalModel) def test_update_using_conditional(self): - t = TestConditionalModel.create(text='blah blah') + t = TestConditionalModel.if_not_exists().create(text='blah blah') t.text = 'new blah' with mock.patch.object(self.session, 'execute') as m: t.iff(text='blah blah').save() @@ -59,7 +63,7 @@ def test_update_using_conditional(self): self.assertIn('IF "text" = %(0)s', args[0][0].query_string) def test_update_conditional_success(self): - t = TestConditionalModel.create(text='blah blah', count=5) + t = TestConditionalModel.if_not_exists().create(text='blah blah', count=5) id = t.id t.text = 'new blah' t.iff(text='blah blah').save() @@ -69,7 +73,7 @@ def test_update_conditional_success(self): self.assertEqual(updated.text, 'new blah') def test_update_failure(self): - t = TestConditionalModel.create(text='blah blah') + t = TestConditionalModel.if_not_exists().create(text='blah blah') t.text = 'new blah' t = t.iff(text='something wrong') @@ -82,7 +86,7 @@ def test_update_failure(self): }) def test_blind_update(self): - t = TestConditionalModel.create(text='blah blah') + t = TestConditionalModel.if_not_exists().create(text='blah blah') t.text = 'something else' uid = t.id @@ -93,7 +97,7 @@ def test_blind_update(self): self.assertIn('IF "text" = %(1)s', args[0][0].query_string) def test_blind_update_fail(self): - t = TestConditionalModel.create(text='blah blah') + t = TestConditionalModel.if_not_exists().create(text='blah blah') t.text = 'something else' uid = t.id qs = TestConditionalModel.objects(id=uid).iff(text='Not dis!') @@ -109,11 +113,11 @@ def test_conditional_clause(self): tc = ConditionalClause('some_value', 23) tc.set_context_id(3) - self.assertEqual('"some_value" = %(3)s', six.text_type(tc)) + self.assertEqual('"some_value" = %(3)s', str(tc)) self.assertEqual('"some_value" = %(3)s', str(tc)) def test_batch_update_conditional(self): - t = TestConditionalModel.create(text='something', count=5) + t = TestConditionalModel.if_not_exists().create(text='something', count=5) id = t.id with BatchQuery() as b: t.batch(b).iff(count=5).update(text='something else') @@ -135,9 +139,30 @@ def test_batch_update_conditional(self): updated = TestConditionalModel.objects(id=id).first() self.assertEqual(updated.text, 'something else') + @unittest.skip("Skipping until PYTHON-943 is resolved") + def test_batch_update_conditional_several_rows(self): + sync_table(TestUpdateModel) + self.addCleanup(drop_table, TestUpdateModel) + + first_row = 
TestUpdateModel.create(partition=1, cluster=1, value=5, text="something") + second_row = TestUpdateModel.create(partition=1, cluster=2, value=5, text="something") + + b = BatchQuery() + TestUpdateModel.batch(b).if_not_exists().create(partition=1, cluster=1, value=5, text='something else') + TestUpdateModel.batch(b).if_not_exists().create(partition=1, cluster=2, value=5, text='something else') + TestUpdateModel.batch(b).if_not_exists().create(partition=1, cluster=3, value=5, text='something else') + + # The response will be more than two rows because two of the inserts will fail + with self.assertRaises(LWTException): + b.execute() + + first_row.delete() + second_row.delete() + b.execute() + def test_delete_conditional(self): # DML path - t = TestConditionalModel.create(text='something', count=5) + t = TestConditionalModel.if_not_exists().create(text='something', count=5) self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1) with self.assertRaises(LWTException): t.iff(count=9999).delete() @@ -146,7 +171,7 @@ def test_delete_conditional(self): self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 0) # QuerySet path - t = TestConditionalModel.create(text='something', count=5) + t = TestConditionalModel.if_not_exists().create(text='something', count=5) self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1) with self.assertRaises(LWTException): TestConditionalModel.objects(id=t.id).iff(count=9999).delete() @@ -154,13 +179,69 @@ def test_delete_conditional(self): TestConditionalModel.objects(id=t.id).iff(count=5).delete() self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 0) + def test_delete_lwt_ne(self): + """ + Test to ensure that deletes using IF and not equals are honored correctly + + @since 3.2 + @jira_ticket PYTHON-328 + @expected_result Delete conditional with NE should be honored + + @test_category object_mapper + """ + + # DML path + t = TestConditionalModel.if_not_exists().create(text='something', count=5) + self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1) + with self.assertRaises(LWTException): + t.iff(count__ne=5).delete() + t.iff(count__ne=2).delete() + self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 0) + + # QuerySet path + t = TestConditionalModel.if_not_exists().create(text='something', count=5) + self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1) + with self.assertRaises(LWTException): + TestConditionalModel.objects(id=t.id).iff(count__ne=5).delete() + TestConditionalModel.objects(id=t.id).iff(count__ne=2).delete() + self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 0) + + def test_update_lwt_ne(self): + """ + Test to ensure that update using IF and not equals are honored correctly + + @since 3.2 + @jira_ticket PYTHON-328 + @expected_result update conditional with NE should be honored + + @test_category object_mapper + """ + + # DML path + t = TestConditionalModel.if_not_exists().create(text='something', count=5) + self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1) + with self.assertRaises(LWTException): + t.iff(count__ne=5).update(text='nothing') + t.iff(count__ne=2).update(text='nothing') + self.assertEqual(TestConditionalModel.objects(id=t.id).first().text, 'nothing') + t.delete() + + # QuerySet path + t = TestConditionalModel.if_not_exists().create(text='something', count=5) + self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1) + with self.assertRaises(LWTException): + 
TestConditionalModel.objects(id=t.id).iff(count__ne=5).update(text='nothing') + TestConditionalModel.objects(id=t.id).iff(count__ne=2).update(text='nothing') + self.assertEqual(TestConditionalModel.objects(id=t.id).first().text, 'nothing') + t.delete() + def test_update_to_none(self): # This test is done because updates to none are split into deletes # for old versions of cassandra. Can be removed when we drop that code # https://github.com/datastax/python-driver/blob/3.1.1/cassandra/cqlengine/query.py#L1197-L1200 # DML path - t = TestConditionalModel.create(text='something', count=5) + t = TestConditionalModel.if_not_exists().create(text='something', count=5) self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1) with self.assertRaises(LWTException): t.iff(count=9999).update(text=None) @@ -169,10 +250,46 @@ def test_update_to_none(self): self.assertIsNone(TestConditionalModel.objects(id=t.id).first().text) # QuerySet path - t = TestConditionalModel.create(text='something', count=5) + t = TestConditionalModel.if_not_exists().create(text='something', count=5) self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1) with self.assertRaises(LWTException): TestConditionalModel.objects(id=t.id).iff(count=9999).update(text=None) self.assertIsNotNone(TestConditionalModel.objects(id=t.id).first().text) TestConditionalModel.objects(id=t.id).iff(count=5).update(text=None) self.assertIsNone(TestConditionalModel.objects(id=t.id).first().text) + + def test_column_delete_after_update(self): + # DML path + t = TestConditionalModel.if_not_exists().create(text='something', count=5) + t.iff(count=5).update(text=None, count=6) + + self.assertIsNone(t.text) + self.assertEqual(t.count, 6) + + # QuerySet path + t = TestConditionalModel.if_not_exists().create(text='something', count=5) + TestConditionalModel.objects(id=t.id).iff(count=5).update(text=None, count=6) + + self.assertIsNone(TestConditionalModel.objects(id=t.id).first().text) + self.assertEqual(TestConditionalModel.objects(id=t.id).first().count, 6) + + def test_conditional_without_instance(self): + """ + Test to ensure that the iff method is honored if it's called + directly from the Model class + + @jira_ticket PYTHON-505 + @expected_result the value is updated + + @test_category object_mapper + """ + uuid = uuid4() + TestConditionalModel.if_not_exists().create(id=uuid, text='test_for_cassandra', count=5) + + # This uses the iff method directly from the model class without + # an instance having been created + TestConditionalModel.iff(count=5).filter(id=uuid).update(text=None, count=6) + + t = TestConditionalModel.filter(id=uuid).first() + self.assertIsNone(t.text) + self.assertEqual(t.count, 6) diff --git a/tests/integration/cqlengine/test_timestamp.py b/tests/integration/cqlengine/test_timestamp.py index 9b3ffbeafa..c68fc8fa5b 100644 --- a/tests/integration/cqlengine/test_timestamp.py +++ b/tests/integration/cqlengine/test_timestamp.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
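The test_timestamp.py hunks that follow replace the `sure` assertion helpers with plain unittest assertions and add coverage for chaining ttl() with timestamp(). As a rough sketch of the cqlengine API under test, with a hypothetical model and an already-configured connection, the calls look like this::

    from datetime import timedelta
    from uuid import uuid4

    from cassandra.cqlengine import columns
    from cassandra.cqlengine.models import Model
    from cassandra.cqlengine.query import BatchQuery

    class ExampleTimestampModel(Model):
        # Hypothetical model; assumes connection.setup() and sync_table() already ran.
        id = columns.UUID(primary_key=True, default=uuid4)
        count = columns.Integer()

    # Per-write timestamp: the generated CQL carries "USING TIMESTAMP".
    ExampleTimestampModel.timestamp(timedelta(seconds=30)).create(count=1)

    # TTL and timestamp chain together: "USING TTL 30 AND TIMESTAMP ...".
    ExampleTimestampModel.timestamp(timedelta(seconds=30)).ttl(30).create(count=1)

    # A timestamp given to a BatchQuery applies to every statement in the batch.
    with BatchQuery(timestamp=timedelta(seconds=30)) as b:
        ExampleTimestampModel.batch(b).create(count=1)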
You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +15,7 @@ # limitations under the License. from datetime import timedelta, datetime -import mock -import sure +from unittest import mock from uuid import uuid4 from cassandra.cqlengine import columns @@ -44,7 +45,7 @@ def test_batch_is_included(self): with BatchQuery(timestamp=timedelta(seconds=30)) as b: TestTimestampModel.batch(b).create(count=1) - "USING TIMESTAMP".should.be.within(m.call_args[0][0].query_string) + self.assertIn("USING TIMESTAMP", m.call_args[0][0].query_string) class CreateWithTimestampTest(BaseTimestampTest): @@ -56,23 +57,27 @@ def test_batch(self): query = m.call_args[0][0].query_string - query.should.match(r"INSERT.*USING TIMESTAMP") - query.should_not.match(r"TIMESTAMP.*INSERT") + self.assertRegex(query, r"INSERT.*USING TIMESTAMP") + self.assertNotRegex(query, r"TIMESTAMP.*INSERT") def test_timestamp_not_included_on_normal_create(self): with mock.patch.object(self.session, "execute") as m: TestTimestampModel.create(count=2) - "USING TIMESTAMP".shouldnt.be.within(m.call_args[0][0].query_string) + self.assertNotIn("USING TIMESTAMP", m.call_args[0][0].query_string) def test_timestamp_is_set_on_model_queryset(self): delta = timedelta(seconds=30) tmp = TestTimestampModel.timestamp(delta) - tmp._timestamp.should.equal(delta) + self.assertEqual(tmp._timestamp, delta) def test_non_batch_syntax_integration(self): tmp = TestTimestampModel.timestamp(timedelta(seconds=30)).create(count=1) - tmp.should.be.ok + self.assertIsNotNone(tmp) + + def test_non_batch_syntax_with_tll_integration(self): + tmp = TestTimestampModel.timestamp(timedelta(seconds=30)).ttl(30).create(count=1) + self.assertIsNotNone(tmp) def test_non_batch_syntax_unit(self): @@ -81,7 +86,17 @@ def test_non_batch_syntax_unit(self): query = m.call_args[0][0].query_string - "USING TIMESTAMP".should.be.within(query) + self.assertIn("USING TIMESTAMP", query) + + def test_non_batch_syntax_with_ttl_unit(self): + + with mock.patch.object(self.session, "execute") as m: + TestTimestampModel.timestamp(timedelta(seconds=30)).ttl(30).create( + count=1) + + query = m.call_args[0][0].query_string + + self.assertRegex(query, r"USING TTL \d* AND TIMESTAMP") class UpdateWithTimestampTest(BaseTimestampTest): @@ -95,7 +110,7 @@ def test_instance_update_includes_timestamp_in_query(self): with mock.patch.object(self.session, "execute") as m: self.instance.timestamp(timedelta(seconds=30)).update(count=2) - "USING TIMESTAMP".should.be.within(m.call_args[0][0].query_string) + self.assertIn("USING TIMESTAMP", m.call_args[0][0].query_string) def test_instance_update_in_batch(self): with mock.patch.object(self.session, "execute") as m: @@ -103,7 +118,7 @@ def test_instance_update_in_batch(self): self.instance.batch(b).timestamp(timedelta(seconds=30)).update(count=2) query = m.call_args[0][0].query_string - "USING TIMESTAMP".should.be.within(query) + self.assertIn("USING TIMESTAMP", query) class DeleteWithTimestampTest(BaseTimestampTest): @@ -115,7 +130,7 @@ def test_non_batch(self): uid = uuid4() tmp = TestTimestampModel.create(id=uid, count=1) - 
TestTimestampModel.get(id=uid).should.be.ok + self.assertIsNotNone(TestTimestampModel.get(id=uid)) tmp.timestamp(timedelta(seconds=5)).delete() @@ -129,15 +144,15 @@ def test_non_batch(self): # calling .timestamp sets the TS on the model tmp.timestamp(timedelta(seconds=5)) - tmp._timestamp.should.be.ok + self.assertIsNotNone(tmp._timestamp) # calling save clears the set timestamp tmp.save() - tmp._timestamp.shouldnt.be.ok + self.assertIsNone(tmp._timestamp) tmp.timestamp(timedelta(seconds=5)) tmp.update() - tmp._timestamp.shouldnt.be.ok + self.assertIsNone(tmp._timestamp) def test_blind_delete(self): """ @@ -146,7 +161,7 @@ def test_blind_delete(self): uid = uuid4() tmp = TestTimestampModel.create(id=uid, count=1) - TestTimestampModel.get(id=uid).should.be.ok + self.assertIsNotNone(TestTimestampModel.get(id=uid)) TestTimestampModel.objects(id=uid).timestamp(timedelta(seconds=5)).delete() @@ -165,7 +180,7 @@ def test_blind_delete_with_datetime(self): uid = uuid4() tmp = TestTimestampModel.create(id=uid, count=1) - TestTimestampModel.get(id=uid).should.be.ok + self.assertIsNotNone(TestTimestampModel.get(id=uid)) plus_five_seconds = datetime.now() + timedelta(seconds=5) @@ -183,11 +198,9 @@ def test_delete_in_the_past(self): uid = uuid4() tmp = TestTimestampModel.create(id=uid, count=1) - TestTimestampModel.get(id=uid).should.be.ok + self.assertIsNotNone(TestTimestampModel.get(id=uid)) - # delete the in past, should not affect the object created above + # delete in the past, should not affect the object created above TestTimestampModel.objects(id=uid).timestamp(timedelta(seconds=-60)).delete() TestTimestampModel.get(id=uid) - - diff --git a/tests/integration/cqlengine/test_ttl.py b/tests/integration/cqlengine/test_ttl.py index ba2c1e0935..0e0f8d2c28 100644 --- a/tests/integration/cqlengine/test_ttl.py +++ b/tests/integration/cqlengine/test_ttl.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -13,19 +15,19 @@ # limitations under the License. 
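The test_ttl.py changes below compare CASSANDRA_VERSION as a packaging Version, read the table's default_time_to_live back from the system schema, and stop expecting a model-level _ttl when only the table default is in play, since the server applies it. A minimal sketch of how a cqlengine model declares and overrides a default TTL, using a hypothetical ExampleTTLModel and an already-configured connection, might look like::

    from uuid import uuid4

    from cassandra.cqlengine import columns
    from cassandra.cqlengine.management import sync_table
    from cassandra.cqlengine.models import Model

    class ExampleTTLModel(Model):
        # Hypothetical model; 'default_time_to_live' becomes a table option on sync_table().
        __options__ = {'default_time_to_live': 20}
        id = columns.UUID(primary_key=True, default=uuid4)
        text = columns.Text(required=False)

    sync_table(ExampleTTLModel)

    # The table default is applied server-side; the generated INSERT has no USING TTL.
    row = ExampleTTLModel.create(text='expires via the table default')

    # An explicit per-query TTL does add USING TTL and overrides the table default.
    ExampleTTLModel.ttl(3600).create(text='expires in an hour')

    # ttl(None) suppresses the clause on updates, so the table default applies again.
    ExampleTTLModel.objects(id=row.id).ttl(None).update(text='table default again')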
-try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest + +from packaging.version import Version +from cassandra import InvalidRequest from cassandra.cqlengine.management import sync_table, drop_table from tests.integration.cqlengine.base import BaseCassEngTestCase from cassandra.cqlengine.models import Model from uuid import uuid4 from cassandra.cqlengine import columns -import mock +from unittest import mock from cassandra.cqlengine.connection import get_session -from tests.integration import CASSANDRA_VERSION +from tests.integration import CASSANDRA_VERSION, greaterthancass20 class TestTTLModel(Model): @@ -58,14 +60,14 @@ class BaseDefaultTTLTest(BaseCassEngTestCase): @classmethod def setUpClass(cls): - if CASSANDRA_VERSION >= '2.0': + if CASSANDRA_VERSION >= Version('2.0'): super(BaseDefaultTTLTest, cls).setUpClass() sync_table(TestDefaultTTLModel) sync_table(TestTTLModel) @classmethod def tearDownClass(cls): - if CASSANDRA_VERSION >= '2.0': + if CASSANDRA_VERSION >= Version('2.0'): super(BaseDefaultTTLTest, cls).tearDownClass() drop_table(TestDefaultTTLModel) drop_table(TestTTLModel) @@ -156,8 +158,17 @@ def test_ttl_included_with_blind_update(self): self.assertIn("USING TTL", query) -@unittest.skipIf(CASSANDRA_VERSION < '2.0', "default_time_to_Live was introduce in C* 2.0, currently running {0}".format(CASSANDRA_VERSION)) class TTLDefaultTest(BaseDefaultTTLTest): + def get_default_ttl(self, table_name): + session = get_session() + try: + default_ttl = session.execute("SELECT default_time_to_live FROM system_schema.tables " + "WHERE keyspace_name = 'cqlengine_test' AND table_name = '{0}'".format(table_name)) + except InvalidRequest: + default_ttl = session.execute("SELECT default_time_to_live FROM system.schema_columnfamilies " + "WHERE keyspace_name = 'cqlengine_test' AND columnfamily_name = '{0}'".format(table_name)) + return default_ttl[0]['default_time_to_live'] + def test_default_ttl_not_set(self): session = get_session() @@ -166,36 +177,60 @@ def test_default_ttl_not_set(self): self.assertIsNone(o._ttl) + default_ttl = self.get_default_ttl('test_ttlmodel') + self.assertEqual(default_ttl, 0) + with mock.patch.object(session, 'execute') as m: - TestTTLModel.objects(id=tid).update(text="aligators") + TestTTLModel.objects(id=tid).update(text="alligators") query = m.call_args[0][0].query_string self.assertNotIn("USING TTL", query) def test_default_ttl_set(self): session = get_session() + o = TestDefaultTTLModel.create(text="some text on ttl") tid = o.id - self.assertEqual(o._ttl, TestDefaultTTLModel.__default_ttl__) + # Should not be set, it's handled by Cassandra + self.assertIsNone(o._ttl) + + default_ttl = self.get_default_ttl('test_default_ttlmodel') + self.assertEqual(default_ttl, 20) with mock.patch.object(session, 'execute') as m: - TestDefaultTTLModel.objects(id=tid).update(text="aligators expired") + TestTTLModel.objects(id=tid).update(text="alligators expired") + # Should not be set either query = m.call_args[0][0].query_string - self.assertIn("USING TTL", query) + self.assertNotIn("USING TTL", query) + + def test_default_ttl_modify(self): + session = get_session() + + default_ttl = self.get_default_ttl('test_default_ttlmodel') + self.assertEqual(default_ttl, 20) + + TestDefaultTTLModel.__options__ = {'default_time_to_live': 10} + sync_table(TestDefaultTTLModel) + + default_ttl = self.get_default_ttl('test_default_ttlmodel') + self.assertEqual(default_ttl, 10) + + # Restore default TTL + TestDefaultTTLModel.__options__ = 
{'default_time_to_live': 20} + sync_table(TestDefaultTTLModel) def test_override_default_ttl(self): session = get_session() o = TestDefaultTTLModel.create(text="some text on ttl") tid = o.id - self.assertEqual(o._ttl, TestDefaultTTLModel.__default_ttl__) o.ttl(3600) self.assertEqual(o._ttl, 3600) with mock.patch.object(session, 'execute') as m: - TestDefaultTTLModel.objects(id=tid).ttl(None).update(text="aligators expired") + TestDefaultTTLModel.objects(id=tid).ttl(None).update(text="alligators expired") query = m.call_args[0][0].query_string self.assertNotIn("USING TTL", query) diff --git a/tests/integration/datatype_utils.py b/tests/integration/datatype_utils.py index ee9695c289..a4c4cdb4d8 100644 --- a/tests/integration/datatype_utils.py +++ b/tests/integration/datatype_utils.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -14,9 +16,10 @@ from decimal import Decimal from datetime import datetime, date, time +import ipaddress from uuid import uuid1, uuid4 -from cassandra.util import OrderedMap, Date, Time, sortedset +from cassandra.util import OrderedMap, Date, Time, sortedset, Duration from tests.integration import get_server_versions @@ -39,6 +42,8 @@ 'varint', ]) +PRIMITIVE_DATATYPES_KEYS = PRIMITIVE_DATATYPES.copy() + COLLECTION_TYPES = sortedset([ 'list', 'set', @@ -54,6 +59,9 @@ def update_datatypes(): if _cass_version >= (2, 2, 0): PRIMITIVE_DATATYPES.update(['date', 'time', 'smallint', 'tinyint']) + PRIMITIVE_DATATYPES_KEYS.update(['date', 'time', 'smallint', 'tinyint']) + if _cass_version >= (3, 10): + PRIMITIVE_DATATYPES.add('duration') global SAMPLE_DATA SAMPLE_DATA = get_sample_data() @@ -85,7 +93,10 @@ def get_sample_data(): sample_data[datatype] = 3.4028234663852886e+38 elif datatype == 'inet': - sample_data[datatype] = '123.123.123.123' + sample_data[datatype] = ('123.123.123.123', + '2001:db8:85a3:8d3:1319:8a2e:370:7348', + ipaddress.IPv4Address("123.123.123.123"), + ipaddress.IPv6Address('2001:db8:85a3:8d3:1319:8a2e:370:7348')) elif datatype == 'int': sample_data[datatype] = 2147483647 @@ -120,6 +131,9 @@ def get_sample_data(): elif datatype == 'smallint': sample_data[datatype] = 32523 + elif datatype == 'duration': + sample_data[datatype] = Duration(months=2, days=12, nanoseconds=21231) + else: raise Exception("Missing handling of {0}".format(datatype)) @@ -132,10 +146,20 @@ def get_sample(datatype): """ Helper method to access created sample data for primitive types """ - + if isinstance(SAMPLE_DATA[datatype], tuple): + return SAMPLE_DATA[datatype][0] return SAMPLE_DATA[datatype] +def get_all_samples(datatype): + """ + Helper method to access created sample data for primitive types + """ + if isinstance(SAMPLE_DATA[datatype], tuple): + 
return SAMPLE_DATA[datatype] + return SAMPLE_DATA[datatype], + + def get_collection_sample(collection_type, datatype): """ Helper method to access created sample data for collection types diff --git a/tests/integration/long/__init__.py b/tests/integration/long/__init__.py index caa7e71667..f369b97a81 100644 --- a/tests/integration/long/__init__.py +++ b/tests/integration/long/__init__.py @@ -1,16 +1,20 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import unittest + try: from ccmlib import common except ImportError as e: diff --git a/tests/integration/long/ssl/127.0.0.1.keystore b/tests/integration/long/ssl/127.0.0.1.keystore new file mode 100644 index 0000000000..98193ab54e Binary files /dev/null and b/tests/integration/long/ssl/127.0.0.1.keystore differ diff --git a/tests/integration/long/ssl/cassandra.crt b/tests/integration/long/ssl/cassandra.crt deleted file mode 100644 index 432e58540b..0000000000 Binary files a/tests/integration/long/ssl/cassandra.crt and /dev/null differ diff --git a/tests/integration/long/ssl/cassandra.truststore b/tests/integration/long/ssl/cassandra.truststore new file mode 100644 index 0000000000..b31e34b8aa Binary files /dev/null and b/tests/integration/long/ssl/cassandra.truststore differ diff --git a/tests/integration/long/ssl/client.crt_signed b/tests/integration/long/ssl/client.crt_signed new file mode 100644 index 0000000000..db3d903f19 --- /dev/null +++ b/tests/integration/long/ssl/client.crt_signed @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDDjCCAfYCFAG4WryLorTXxNtrkEJ56zUg/XdDMA0GCSqGSIb3DQEBCwUAMEIx +CzAJBgNVBAYTAlVTMREwDwYDVQQKDAhkYXRhc3RheDEPMA0GA1UECwwGZmllbGRz +MQ8wDQYDVQQDDAZyb290Q2EwHhcNMjEwMzE3MTcwNTE4WhcNMjIwMzE3MTcwNTE4 +WjBFMQswCQYDVQQGEwJVUzERMA8GA1UECgwIZGF0YXN0YXgxDzANBgNVBAsMBmZp +ZWxkczESMBAGA1UEAwwJMTI3LjAuMC4xMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEAnrpE3g8pbQn2tVVidX2Ww1rh/6YIH6EGW9hXMO/F506ReMruv+Al +ilc7B2sPpGRDKXupy23IcpfMIe9+Lm74/yu7pW51rJ/r2jMqg+tViFa/GQxSQLKd +AxDAvwJaAM41kro0DKmcm4RwfYAltupwc6pC7AfBtT08PBuDK7WfaNnFbhGAWkHv +MbULNWAKbPWqITHbUEvLgS/uPj+/W4SHk5GaYk0Y2mU3aWypeDOBqEfKTi2W0ix1 +O7SpOHyfA0hvXS9IilF/HWURvr9u13mnvJNe8W+uqWqlQMdyFsbPCIhbVwVwGYQp +yoyBrgz6y5SPwSyugAb2F8Yk3UpvqH30yQIDAQABMA0GCSqGSIb3DQEBCwUAA4IB +AQB5XV+3NS5UpwpTXTYsadLL8XcdGsfITMs4MSv0N3oir++TUzTc3cOd2T6YVdEc +ypw5CKTYnFTK9oF2PZXeV+aLIjdvK4AukQurB8EdXq4Hu7y1b61OaGRqiKTVsIne +LwxCXpc42jqMFt4mMXpmU/hSCjRSvoumTcL1aHUzaPlSIasD2JDyLurO64gxQypi +wbD9gliPJ60pdhY0m9NfF5F2PdqBuJXrhF1VuxYx1/cfo/c1A4UK2slhsZCDls7/ 
+HbM8ri5Z74M1EtCGFcTNYvm0xlfF5arisGQSKhTw+06LnpUlQi5a8NRNBLeAmem/ +cuICJJbnSzjmq9skkp8i/ejH +-----END CERTIFICATE----- diff --git a/tests/integration/long/ssl/client.key b/tests/integration/long/ssl/client.key new file mode 100644 index 0000000000..d6b8811a94 --- /dev/null +++ b/tests/integration/long/ssl/client.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQCeukTeDyltCfa1 +VWJ1fZbDWuH/pggfoQZb2Fcw78XnTpF4yu6/4CWKVzsHaw+kZEMpe6nLbchyl8wh +734ubvj/K7ulbnWsn+vaMyqD61WIVr8ZDFJAsp0DEMC/AloAzjWSujQMqZybhHB9 +gCW26nBzqkLsB8G1PTw8G4MrtZ9o2cVuEYBaQe8xtQs1YAps9aohMdtQS8uBL+4+ +P79bhIeTkZpiTRjaZTdpbKl4M4GoR8pOLZbSLHU7tKk4fJ8DSG9dL0iKUX8dZRG+ +v27Xeae8k17xb66paqVAx3IWxs8IiFtXBXAZhCnKjIGuDPrLlI/BLK6ABvYXxiTd +Sm+offTJAgMBAAECggEAN+VysRx3wy1aEvuRo7xpZjxQD/5BKBpFqfxioBogAFfb +xMT6FNnzfmc/o1ohdQvV1vr0jW4Iw8oPGfhD4Eg2KW4WM6jVicf7f6i7FR+/zDZ4 +L3L2WFBOGLFCn0FNvrDfjt9Byx/DxcR69Mc3ANZIaYMQ9Bu7LH73AlfR9oeMLpjL ++6g1qz2yz8Sm2CMCGXTyXtvUCgn2ld6nz8KlZ8FTUG9C9mAabuvV91Ko6rmTxuiv +YKvHSPnIjXRjuC+Ozjf1rYTOJ5LVMNNhlbIKBG/Nx5QzL7bA3XDtMD1BEI9pdHR+ +5HwA0tV2Ex67tBCJwlBAhYLxuPjfOj1R5KV8wriE3QKBgQDNvqOaGYiXwp9Rajoo +ltlOBPfnjshd9tPdc6tTUQR34vSbkHrg0HVJhvIP5LRbyx/M/8ACQxFkDRE4U7fJ +xVGDs8Pi0FqcqFTnm/AYQ5eZbJkPp9qe71aDOPanncrVNEGFeW26LaeLGbTLrOMM +6mTmsfGig0MKgml35IMrP+oPuwKBgQDFf56DdaFe08xSK9pDWuKxUuBIagGExQkQ +r9eYasBc336CXh3FWtpSlxl73dqtISh/HbKbv+OZfkVdbmkcTVGlWm/N/XvLqpPK +86kbKW6PY8FxIY/RxiZANf/JJ5gzPp6VQMJeSy+oepeWj11mTLcT02plvIMM0Jmg +Z5B9Hw37SwKBgDR/59lDmLI47FRnCc4fp/WbmPKSYZhwimFgyZ/p9XzuAcLMXD6P +ks4fTBc4IbmmnEfAHuu013QzTWiVHDm1SvaTYXG3/tcosPmkteBLJxz0NB5lk4io +w+eaGn5s6jv7KJj5gkFWswDwn0y1of5CtVqUn3b7jZjZ7DW2rq3TklNPAoGAIzaW +56+AfyzaQEhrWRkKVD2HmcG01Zxf+mav1RArjiOXJd1sB3UkehdQxuIOjFHeK5P6 +9YQoK4T1DyyRdydeCFJwntS0TuLyCPyaySoA+XX61pX6U5e12DsIiTATFgfzNH9g +aHmVXL/G6WRUbdn9xn4qeUs8Pnuu+IeenoB7+LMCgYBBnig9nTp81U+SGsNl2D3J +WUz4z+XzEfKU1nq2s4KNjIPB2T1ne+1x3Uso2hagtEHeuEbZoRY4dtCahAvYwrPM +8wtDFQXWmvFyN3X0Js65GZ++knuseQ1tdlbc/4C+k4u26tVe2GcwhKTjn08++L2E +UB3pLXbssswH271OjD+QkQ== +-----END PRIVATE KEY----- diff --git a/tests/integration/long/ssl/client_bad.key b/tests/integration/long/ssl/client_bad.key new file mode 100644 index 0000000000..5d810f25cf --- /dev/null +++ b/tests/integration/long/ssl/client_bad.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDha8+NBvpTmTbw +D2EIodXlaaAEtLmXTGoH8pdBm3JxzMuUEkYbGig3YjQ1BAKQgCB1TJGPINcHz7Jo +5aW5To1jrxhhohZYQLCNKlAONDhgJbHEPf7s8dreQ/q5ISt/2I3z68c9I0j3VoRz +AxxcNktl/x+6YkXe9tXf/LWmJk/gHlu72/HuJ5oNyqOKaCCoMoib3jLTlR+lslTy +Qy/CJZH6WJabLOPmppFLaxJMlSGDSTE/Xktt7+H5ssHnfQtyWyylVjZkzChJfBgh +HrLpm3hO5rmqVwOhoKLKVDFMmX3aMGX2S+3KpXQ8gLnPXwfLI9J9fDg5jp7bya4k +OXlZfB5hAgMBAAECggEBANQVFbmudfgPL4PeREHV2SM1JCspSW9SonOFxs8gDCWL +M4HFS5YWHv40c7/pXOxMz7zsZApQMF8WBtnwLeJRSG8f/oVk9Tbk7fZyd81VTjEP +ZdenKGAPEAeL16kzzvRCbxOtoc8gkna6PHTk2VrcbkWxKU23RduHSiOpY9HFO+Mz +iI69tB7657NOiZCQ6xDIjKv+jR63m7VAWKT5jkN+tYpvx4K20na5t8RO1s0shqNE +e2zMG8WXVl6lW4btfkt/lwWUNXu8olMTk9qN2b5Rq7BEJfKwn3lb9vCpUMyewtRB +/8U+Zu7Tlwni5QagOqAUEkjuOJ8cR/Jgwu1mqV2sXxECgYEA9zXi0PjWAe2ZIALd +1iWPZCvvT7yEjt4ulhAYPqi8T38B4K5f//m5SuYPS2ebmSAd2WBTeIX2A6mHc9lk +53gnwvsgAqaFgjYeDqBThpCE8icFXEZnJbtnJyC8zC7pYjUovAHkFEdLw5kQoI6Y +i9HNOS9ugSut8RnF0oSv/E2mahUCgYEA6W+ZAEneBBCsOQclVfVPLm7D+y+5SZEt +zWr2b7CCnGCev/qRCllIEwQ2+W1ACEHof9xjE+aWwEQjX8YnoVbAJo2ru6FFQfI+ +f/SQx7beX8jUAeJGo+CFr2ijdVmcCCbMGeAm8mpACUIQfWPHVqjtGS/CayxdfwA+ +lbWPbkXCMh0CgYBfUgHRPgGW4LyoYTKUfgsaPu6ZukEKrZUc+7u9fWaO6JQaxGHz 
+26CcxrSjCKIwmvend8L3t/+yTc4S14JW1jfOsPIY04irOp7AWQWb32HD1VP1zpe7 +LtWJetARkw0edwzr4XbGcu89zmlg31rmntEY+bcMS4FYc+2ZTNxm1rISOQKBgGQZ +lct44Xpux9tghBMbMUwg9WtWKKcyWSi4EFsOnsN97zU1tlJwvKZi7UwCHC4uTQvf +LqFPBSAHV//u0fmuYJFnuNeprTA9N63Y6uipMyxxyu/P3yjQ06LHRSjCN1WLhYQn +Cax0AWe266lJSyaPI7TkNQOOL72RFkVOaOYJhd/FAoGAPtpVPTiVK0RYwLnZqaWB +fxyI6w+UjOEbP88vD7N7FEI2kQSGQ6F3pMzDK37NglJVtwjgzEIF9x9BIE8XSf16 +shc0U73Vg9ZsXDNPUz21hhAwYL1cCgnx0mfL88F1Icb5FfxlT/1BPHNHKowA9vST +ihbxCJg/JJBzwXTxPocQisk= +-----END PRIVATE KEY----- diff --git a/tests/integration/long/ssl/client_encrypted.key b/tests/integration/long/ssl/client_encrypted.key new file mode 100644 index 0000000000..49f475d7fe --- /dev/null +++ b/tests/integration/long/ssl/client_encrypted.key @@ -0,0 +1,30 @@ +-----BEGIN RSA PRIVATE KEY----- +Proc-Type: 4,ENCRYPTED +DEK-Info: AES-256-CBC,7288A409E846EBE2DE421B77598DAF98 + +ahiUSf+k9POIEUJb5BGbQ6knk4+FtTz+e+6fouqVc4Lq+RXR9f0pFBi9eDEkFNiN +AcUjLkxh+3TmihTZJprqXSbQ3jacwbnwDOFgtZE3PxoA1heHxADaKCNr+Ph0lC/T +3cIzsoIZ6slk+3n6ERieZRdmvoMH1SY8nXKT5+bLMR4RIjw1y7h26MRhjQS+lXaX +Asd5EOGROCIgefeEBGHAbrlg0FoHy7slqVBxuZphTHKtyK/VK4fRLt6doUzBu5GJ +T2jdrqJCWr5PRn3bAqMemJWxDhZLX4DyNDQPn8riZ8jMbwPOVUSnF8B8re1tNkQ0 +CsH77sYIIjmPdizCdvj91+jH6o7MRCZPvky+PHG/9G5WsPiw5W1i/nrPemT1XJyy +oPRc/fMFfbHmW3HCGqgv2/6Wg+17un/a6UyzXsbNdhDZLCVqtAQ7PSv83z5oUazT +djzFHgxSqRknUY0lOUvP8Rni67MG+Rcksj9HgszhLoC0be64IX0Ey5oc5+pBYrf9 +FVEPsuyyu4aDSRYYATC2E1V/EQRwcvpKEZNFTbqMpQhjrWtlBM/GgQnQBeQdLAGX +yefDSzkH31y5gcdgHLElriWwbHHbcaAmf3e15W94YHgTytJBsQ9A19SmtmgUmo4h +jaFoUooM5mFA8hc/snSe2PdkEefkzS72g8qxa//61LTJAAkVk43dYjoqQ34wq6WR +OB4nn/W2xlfv/ClZJTWf8YvQTrQptJY5VQq/TTEcrXy67Uc0wRHXZK2rTjKeyRj9 +65SkyyXhMopWEl2vX25ReITVfdJ0FgjqI/ugYSf25iOfJtsk+jgrtrswZ+8F2eMq +iAQ+0JSiYmlot2Pn1QCalLjtTz8zeMfXPyo5fbKNMdp52U1cPYld90kUGHZfjqju +GmY/aHa6N8lZGxj8SC/JM36GawaGKe4S/F5BetYJOpaEzkpowqlTC8Syv529rm46 +vvgf+EJL8gRvdtnIEe/qtzbtel299VhaBpuOcApfTDSxRHZmvkCpdHo9I3KgOZB9 +Cqu9Bz+FiJmTk8rGQwmI8EYj38jneEoqA+fN7tUkzxCGacg+x6ke4nOcJzgBhd94 +8DvGclrcAwBY1mlNYRceFJKFXhwLZTKBojZlS8Q9863EAH3DOBLeP85V3YvBD/MK +O+kzPoxN/jPVNho7y4gL7skcqe/IXePzPxBcZrHJjoU7mGVDcVcouRj16XSezMbB +5Pft0/gGiItRJ2+v9DlPjzDfjTuRdS78muaZ4nNqX6B+JmyPJtkb2CdiHz6B21RO +3hjGrffM1nhmYBegyjTVc88IxzYg0T8CZLq1FYxuTZmwyahA520IpwsbfwXxLVMU +5rmou5dj1pVlvoP3l+ivPqugeY3k7UjZ33m5H9p009JR40dybr1S2RbI8Gqhe953 +0bedA4DWvPakODXgYu43al92uR/tyjazeB5t7Iu8uB5Xcm3/Mqoofe9xtdQSCWa0 +jKKvXzSpL1MM2C0bRyYHIkVR65K7Zmi/BzvTaPECo1+Uv+EwqRZRyBzUZKPP8LMq +jTCOBmYaK8+0dTRk8MEzrPW2ihVVJYVMmFyTZKW0iK7kOMKZRkhDCaNSUlPEty7j +-----END RSA PRIVATE KEY----- diff --git a/tests/integration/long/ssl/driver_ca_cert.pem b/tests/integration/long/ssl/driver_ca_cert.pem deleted file mode 100644 index 7e55555767..0000000000 --- a/tests/integration/long/ssl/driver_ca_cert.pem +++ /dev/null @@ -1,16 +0,0 @@ ------BEGIN CERTIFICATE----- -MIICbjCCAdegAwIBAgIEP/N06DANBgkqhkiG9w0BAQsFADBqMQswCQYDVQQGEwJU -RTELMAkGA1UECBMCQ0ExFDASBgNVBAcTC1NhbnRhIENsYXJhMREwDwYDVQQKEwhE -YXRhU3RheDELMAkGA1UECxMCVEUxGDAWBgNVBAMTD1BoaWxpcCBUaG9tcHNvbjAe -Fw0xNDEwMDMxNTQ2NDdaFw0xNTAxMDExNTQ2NDdaMGoxCzAJBgNVBAYTAlRFMQsw -CQYDVQQIEwJDQTEUMBIGA1UEBxMLU2FudGEgQ2xhcmExETAPBgNVBAoTCERhdGFT -dGF4MQswCQYDVQQLEwJURTEYMBYGA1UEAxMPUGhpbGlwIFRob21wc29uMIGfMA0G -CSqGSIb3DQEBAQUAA4GNADCBiQKBgQChGDwrhpQR0d+NoqilMgsBlR6A2Dd1oMyI -Ue42sU4tN63g5N4adasfasdfsWgnAkP332ok3YAuVbxytwEv2K9HrUSiokAiuinl -hhHA8CXTHt/1ItzzWj9uJ3Hneb+5lOkXVTZX7Y+q3aSdpx/HnZqn4i27DtLZF0z3 -LccWPWRinQIDAQABoyEwHzAdBgNVHQ4EFgQU9WJpUhgGTBBH4xZBCV7Y9YISCp4w 
-DQYJKoZIhvcNAQELBQADgYEAF6e8eVAjoZhfyJ+jW5mB0pXa2vr5b7VFQ45voNnc -GrB3aNbz/AWT7LCJw88+Y5SJITgwN/8o3ZY6Y3MyiqeQYGo9WxDSWb5AdZWFa03Z -+hrVDQuw1r118zIhdS4KYDCQM2JfWY32TwK0MNG/6BO876HfkDpcjCYzq8Gh0gEg -uOA= ------END CERTIFICATE----- diff --git a/tests/integration/long/ssl/python_driver.crt b/tests/integration/long/ssl/python_driver.crt deleted file mode 100644 index 0a419f4eb1..0000000000 Binary files a/tests/integration/long/ssl/python_driver.crt and /dev/null differ diff --git a/tests/integration/long/ssl/python_driver.jks b/tests/integration/long/ssl/python_driver.jks deleted file mode 100644 index 9a0fd59f73..0000000000 Binary files a/tests/integration/long/ssl/python_driver.jks and /dev/null differ diff --git a/tests/integration/long/ssl/python_driver.key b/tests/integration/long/ssl/python_driver.key deleted file mode 100644 index afd73b298c..0000000000 --- a/tests/integration/long/ssl/python_driver.key +++ /dev/null @@ -1,34 +0,0 @@ -Bag Attributes - friendlyName: python_driver - localKeyID: 54 69 6D 65 20 31 34 33 35 33 33 33 34 30 34 33 33 32 -Key Attributes: ------BEGIN RSA PRIVATE KEY----- -Proc-Type: 4,ENCRYPTED -DEK-Info: DES-EDE3-CBC,8A0BC9CFBBB36D47 - -J3Rh82LhsNdIdCV4KCp758VIJJnmedwtq/I9oxH5kY4XoUQjfNcvLGlEnbAUD6+N -mYnQ5XPDvD7iC19XvlA9gfaoWERq+zroGEP+e4dX1X5RlT6YQBJpJR8IW4DWngDM -Nv6CuaGFJWMH8QUvKlJyFOPOHBqbhsCRaxg3pOG3RyUFXpGPDV0ySUyp6poHE9KE -pEVif/SdS3AhV2sb4tyBS9sRZdH1eeCN4gY6k9PQWyNViAgUYAG5xWsE4fITa3qY -gisyzbOYU8ue2QvmjPJgieiKPQf+st/ZRV5eQUCdUgAfLEnULGJXRZ5kw7kMXL0X -gLaKFbGxv4pKQCDCZQq4GXIA/nmTy6cme+VPKwq3usm+GdxfdWQJjgG65+AFaut/ -XjGm1fvSQzWuzpesfLy57HMK+bBh1/UKjuQa3wAHtgPtJLtUSW+/qBnQRdBbl20C -dJtJXyyTlX6H8bQBIfBLc4ntUwS8fVd2jsYJRpCBY6HdtpfsZZ5gQnm1Nux4ksKn -rYYx3I/JpChr7AV7Yj/lwc3Zca/VJl16CjyWeRTQEvkl6GK958aIzj73HfXleZc6 -HGVfOgo2BLmOzY0ZCq/Wa4fnURRgrC3SusrT9mjVbID91oNYw4BjMEU53u0uxPC+ -rr6SwG2EUVawGTVK4XZw2DINCPP/wsKqf0xqA+sxArcTN/MEdLUBdf8oDntkj2jG -Oy0kwpjqhSvWo1DqYKZjV/wKT2SS18OMAW+0qplbHw1/FDGWK+OseD8MXwBo06a5 -LWRQXhf0kEXUQ+oNj3eahe/npHiNChR6mEiIbCuE3NAXPPXJNkhMuj2f5EqrOPfZ -jqbNiLfKKx7L5t6B8LXkdKGPqztcFlnB8rRF9Eqa8F4wiEg8MBLrPyxgd/uT+NIz -LdDgvUE+IkCwQoYoCU70ApiEOyQNacuSxwUiVWVyn9CJYXPM4Vlje7GDIDRR5Xp6 -zNf0ktNP46PsRqDlYG9hZWndj4PRaAqtatlEEm37rmyouVBe3rxcbL1b1zsH/p1I -eaGGTyZ8+iEiuEk4gCOmfmYmpE7H/DXlQvtDRblid/bEY64Uietx0HQ5yZwXZYi8 -hb4itke6xkgRQEIXVyQOdU88PEuA5yofEGoXkfdLgtdu3erPrVDc+nQTYrMWNacR -JQljfhAFJdjOw81Yd5PnFHAtxcxzqEkWv0TGQLL1VjJdinhI7q/fIPLJ76FtuGmt -zlxo/Jy1aaUgM/e485+7aoNSGi2/t6zGqGuotdUCO5epgrUHX+9fOJnnrYTG9ixp -FSHTT69y72khnw3eMP8NnOS3Lu+xLEzQHNbUDfB8uyVEX4pyA3FPVVqwIaeJDiPS -2x7Sl5KKwLbqPPKRFRC1qLsN4KcqeXBG+piTLPExdzsLbrU9JZMcaNmSmUabdg20 -SCwIuU2kHEpO7O7yNGeV9m0CGFUaoCAHVG70oXHxpVjAJbtgyoBkiwSxghCxXkfW -Mg+1B2k4Gk1WrLjIyasH6p0MLUJ7qLYN+c+wF7ms00F/w04rM6zUpkgnqsazpw6F -weUhpA8qY2vOJN6rsB4byaOUnd33xhAwcY/pIAcjW7UBjNmFMB1DQg== ------END RSA PRIVATE KEY----- diff --git a/tests/integration/long/ssl/python_driver.pem b/tests/integration/long/ssl/python_driver.pem deleted file mode 100644 index 83556fd9ce..0000000000 --- a/tests/integration/long/ssl/python_driver.pem +++ /dev/null @@ -1,19 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIDrzCCApegAwIBAgIEFPORBzANBgkqhkiG9w0BAQsFADCBhjELMAkGA1UEBhMCVVMxEzARBgNV -BAgTCkNhbGlmb3JuaWExFDASBgNVBAcTC1NhbnRhIENsYXJhMRYwFAYDVQQKEw1EYXRhU3RheCBJ -bmMuMRwwGgYDVQQLExNQeXRob24gRHJpdmVyIFRlc3RzMRYwFAYDVQQDEw1QeXRob24gRHJpdmVy -MCAXDTE1MDYyNTE3MDAxOFoYDzIxMTUwNjAxMTcwMDE4WjCBhjELMAkGA1UEBhMCVVMxEzARBgNV -BAgTCkNhbGlmb3JuaWExFDASBgNVBAcTC1NhbnRhIENsYXJhMRYwFAYDVQQKEw1EYXRhU3RheCBJ 
-bmMuMRwwGgYDVQQLExNQeXRob24gRHJpdmVyIFRlc3RzMRYwFAYDVQQDEw1QeXRob24gRHJpdmVy -MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAjUi6xfmieLqx9yD7HhkB6sjHfbS51xE7 -aaRySjTA1p9mPfPMPPMZ0NIsiDsUlT7Fa+LWZU9cGuJBFg8YxjU5Eij7smFd0J4tawk51KudDdUZ -crALGFC3WY7sEboMl8UHqV+kESPlNm5/JSNSYkNm1TMi9mHcB/Bg3RDpORRW/keMtBSLRxCVjsu6 -GvKN8wuEfU/bTmI9aUjbFRCFunBX6QEJeU44BYEJXNAls+X8szBfVmFHwefatSlh++uu7kY6zAQI -v74PHMZ8w+mWmbjpxEsmSg+uljGCjQHjKTNSFBY9kWWh2LBiTcZuEsQ9DK0J/+1tUa0s5vq6CjUK -XRxwpQIDAQABoyEwHzAdBgNVHQ4EFgQUJwTYG8dcZDt7faalYwCHmG3jp3swDQYJKoZIhvcNAQEL -BQADggEBABtg3SLFUkcbISoZO4/UdHY2z4BTJZXt5uep9qIVQu7NospzsafgyGF0YAQJq0fLhBlB -DVx6IxIvDZUfzKdIVMYJTQh7ZJ7kdsdhcRIhKZK4Lko3iOwkWS0aXsbQP+hcXrwGViYIV6+Rrmle -LuxwexVfJ+wXCJcc4vvbecVsOs2+ms1w98cUXvVS1d9KpHo37LK1mRsnYPik3+CBeYXqa8FzMJc1 -dlC/dNwrCXYJZ1QMEpyaP4TI3fmkg8OJ3glZkQr6nz1TUMwMmAvudb79IrmQKBuO6k99DZFJC6Er -oh6ff8G/F5YY+dWEqsF0KqNhL9uwyrqG3CTX5Eocg2AGkWI= ------END CERTIFICATE----- diff --git a/tests/integration/long/ssl/python_driver_bad.pem b/tests/integration/long/ssl/python_driver_bad.pem deleted file mode 100644 index 978d6c53f3..0000000000 --- a/tests/integration/long/ssl/python_driver_bad.pem +++ /dev/null @@ -1,19 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIDrzCCApegAwIBAgIEFPORBzANBgkqhkiG9w0BAQsFADCBhjELMAkGA1UEBhMCVVMxEzARBgNV -BAgTCkNhbGlmb3JuaWExFDASBgNVBAcTC1NhbnRhIENsYXJhMRYwFAYDVQQKEw1EYXRhU3RheCBJ -bmMuMRwwGgYDVQQLExNQeXRob24gRHJpdmVyIFRlc3RzMRYwFAYDVQQDEw1QeXRob24gRHJpdmVy -MCAXDTE1MDYyNTE3MDAxOFoYDzIxMTUwNjAxMTcwMDE4WjCBhjELMAkGA1UEBhMCVVMxEzARBgNV -BAgTCkNhbGlmb3JuaWExFDASBgNVBAcTC1NhbnRhIENsYXJhMRYwFAYDVQQKEw1EYXRhU3RheCBJ -bmMuMRwwGgYDVQQLExNQeXRob24gRHJpdmVyIFRlc3RzMRYwFAYDVQQDEw1QeXRob24gRHJpdmVy -MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAjUi6xfmieLqx9yD7HhkB6sjHfbS51xE7 -aaRySjTA1p9mPfPMPPMZ0NIsiDsUlT7Fa+LWZU9cGuJBFg8YxjU5Eij7smFd0J4tawk51KudDdUZ -crALGFC3WY7sEboMl8UHqV+kESPlNm5/JSNSYkNm1TMi9mHcB/Bg3RDpORRW/keMtBSLRxCVjsu6 -GvKN8wuEfU/bTmI9aUjbFRCFunBX6QEJeU44BYEJXNAls+X8szBfVmFHwefatSlh++uu7kY6zAQI -v74PHMZ8w+mWmbjpxEsmSg+uljGCjQHjKTNSFBY9kWWh2LBiTcZuEsQ9DK0J/+1tUa0s5vq6CjUK -XRxwpQIDAQABoyE666666gNVHQ4EFgQUJwTYG8dcZDt7faalYwCHmG3jp3swDQYJKoZIhvcNAQEL -BQADggEBABtg3SLFUkcbISoZO4/UdHY2z4BTJZXt5uep9qIVQu7NospzsafgyGF0YAQJq0fLhBlB -DVx6IxIvDZUfzKdIVMYJTQh7ZJ7kdsdhcRIhKZK4Lko3iOwkWS0aXsbQP+hcXrwGViYIV6+Rrmle -LuxwexVfJ+wXCJcc4vvbecVsOs2+ms1w98cUXvVS1d9KpHo37LK1mRsnYPik3+CBeYXqa8FzMJc1 -dlC/dNwrCXYJZ1QMEpyaP4TI3fmkg8OJ3glZkQr6nz1TUMwMmAvudb79IrmQKBuO6k99DZFJC6Er -oh6ff8G/F5YY+dWEqsF0KqNhL9uwyrqG3CTX5Eocg2AGkWI= ------END CERTIFICATE----- diff --git a/tests/integration/long/ssl/python_driver_no_pass.key b/tests/integration/long/ssl/python_driver_no_pass.key deleted file mode 100644 index 8dd14f84f0..0000000000 --- a/tests/integration/long/ssl/python_driver_no_pass.key +++ /dev/null @@ -1,27 +0,0 @@ ------BEGIN RSA PRIVATE KEY----- -MIIEowIBAAKCAQEAjUi6xfmieLqx9yD7HhkB6sjHfbS51xE7aaRySjTA1p9mPfPM -PPMZ0NIsiDsUlT7Fa+LWZU9cGuJBFg8YxjU5Eij7smFd0J4tawk51KudDdUZcrAL -GFC3WY7sEboMl8UHqV+kESPlNm5/JSNSYkNm1TMi9mHcB/Bg3RDpORRW/keMtBSL -RxCVjsu6GvKN8wuEfU/bTmI9aUjbFRCFunBX6QEJeU44BYEJXNAls+X8szBfVmFH -wefatSlh++uu7kY6zAQIv74PHMZ8w+mWmbjpxEsmSg+uljGCjQHjKTNSFBY9kWWh -2LBiTcZuEsQ9DK0J/+1tUa0s5vq6CjUKXRxwpQIDAQABAoIBAC3bpYQM+wdk0c79 -DYU/aLfkY5wRxSBhn38yuUYMyWrgYjdJoslFvuNg1MODKbMnpLzX6+8GS0cOmUGn -tMrhC50xYEEOCX1lWiib3gGBkoCi4pevPGqwCFMxaL54PQ4mDc6UFJTbqdJ5Gxva -0yrB5ebdqkN+kASjqU0X6Bt21qXB6BvwAgpIXSX8r+NoH2Z9dumSYD+bOwhXo+/b -FQ1wyLL78tDdlJ8KibwnTv9RtLQbALUinMEHyP+4Gp/t/JnxlcAfvEwggYBxFR1K 
-5sN8dMFbMZVNqNREXZyWCMQqPbKLhIHPHlNo5pJP7cUh9iVH4QwYNIbOqUza/aUx -z7DIISECgYEAvpAAdDiBExMOELz4+ku5Uk6wmVOMnAK6El4ijOXjJsOB4FB6M0A6 -THXlzLws0YLcoZ3Pm91z20rqmkv1VG+En27uKC1Dgqqd4DOQzMuPoPxzq/q2ozFH -V5U1a0tTmyynr3CFzQUJKLJs1pKKIp6HMiB48JWQc5q6ZaaomEnOiYsCgYEAvczB -Bwwf7oaZGhson1HdcYs5kUm9VkL/25dELUt6uq5AB5jjvfOYd7HatngNRCabUCgE -gcaNfJSwpbOEZ00AxKVSxGmyIP1YAlkVcSdfAPwGO6C1+V4EPHqYUW0AVHOYo7oB -0MCyLT6nSUNiHWyI7qSEwCP03SqyAKA1pDRUVI8CgYBt+bEpYYqsNW0Cn+yYlqcH -Jz6n3h3h03kLLKSH6AwlzOLhT9CWT1TV15ydgWPkLb+ize6Ip087mYq3LWsSJaHG -WUC8kxLJECo4v8mrRzdG0yr2b6SDnebsVsITf89qWGUVzLyLS4Kzp/VECCIMRK0F -ctQZFFffP8ae74WRDddSbQKBgQC7vZ9qEyo6zNUAp8Ck51t+BtNozWIFw7xGP/hm -PXUm11nqqecMa7pzG3BWcaXdtbqHrS3YGMi3ZHTfUxUzAU4zNb0LH+ndC/xURj4Z -cXJeDO01aiDWi5LxJ+snEAT1hGqF+WX2UcVtT741j/urU0KXnBDb5jU92A++4rps -tH5+LQKBgGHtOWD+ffKNw7IrVLhP16GmYoZZ05zh10d1eUa0ifgczjdAsuEH5/Aq -zK7MsDyPcQBH/pOwAcifWGEdXmn9hL6w5dn96ABfa8Qh9nXWrCE2OFD81PDU9Osd -wnwbTKlYWPBwdF7UCseKC7gXkUD6Ls0ADWJvrCI7AfQJv6jj6nnE ------END RSA PRIVATE KEY----- diff --git a/tests/integration/long/ssl/rootCa.crt b/tests/integration/long/ssl/rootCa.crt new file mode 100644 index 0000000000..a0a0ec73cf --- /dev/null +++ b/tests/integration/long/ssl/rootCa.crt @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDCzCCAfMCFCoTNYhIQpOXMBnAq8Bw72qfKwGLMA0GCSqGSIb3DQEBCwUAMEIx +CzAJBgNVBAYTAlVTMREwDwYDVQQKDAhkYXRhc3RheDEPMA0GA1UECwwGZmllbGRz +MQ8wDQYDVQQDDAZyb290Q2EwHhcNMjEwMzE3MTcwNTE2WhcNMzEwMzE1MTcwNTE2 +WjBCMQswCQYDVQQGEwJVUzERMA8GA1UECgwIZGF0YXN0YXgxDzANBgNVBAsMBmZp +ZWxkczEPMA0GA1UEAwwGcm9vdENhMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB +CgKCAQEApFoQtNu0+XQuMBPle4WAJYIMR74HL15uk9ToKBqMEXL7ah3r23xTTeGr +NyUXicM6Owiup7DK27F4vni+MYKAn7L4uZ99mW0ATYNXBDLFB+wwy1JBk4Dw5+eZ +q9lz1TGK7uBvTOXCllOA2qxRqtMTl2aPy5OuciWQe794abwFqs5+1l9GEuzJGsp1 +P9L4yljbmijC8RmvDFAeUZoKRdKXw2G5kUOHqK9Aej5gLxIK920PezpgLxm0V/PD +ZAlwlsW0vT79RgZCF/vtKcKSLtFTHgPBNPPbkZmOdE7s/6KoAkORBV/9CIsKeTC3 +Y/YeYQ2+G0gxiq1RcMavPw8f58POTQIDAQABMA0GCSqGSIb3DQEBCwUAA4IBAQA1 +MXBlk6u2oVBM+4SyYc2nsaHyerM+omUEysAUNFJq6S6i0pu32ULcusDfrnrIQoyR +xPJ/GSYqZkIDX0s9LvPVD6A6bnugR+Z6VfEniLkG1+TkFC+JMCblgJyaF/EbuayU +3iJX+uj7ikTySjMSDvXxOHik2i0aOh90B/351+sFnSPQrFDQ0XqxeG8s0d7EiLTV +wWJmsYglSeTo1vF3ilVRwjmHO9sX6cmQhRvRNmiQrdWaM3gLS5F6yoQ2UQQ3YdFp +quhYuNwy0Ip6ZpORHYtzkCKSanz/oUh17QWvi7aaJyqD5G5hWZgn3R4RCutoOHRS +TEJ+xzhY768rpsrrNUou +-----END CERTIFICATE----- diff --git a/tests/integration/long/ssl/server_cert.pem b/tests/integration/long/ssl/server_cert.pem deleted file mode 100644 index 7c96b96ade..0000000000 --- a/tests/integration/long/ssl/server_cert.pem +++ /dev/null @@ -1,13 +0,0 @@ ------BEGIN CERTIFICATE----- -MIICbjCCAdegAwIBAgIEP/N06DANBgkqhkiG9w0BAQsFADBqMQswCQYDVQQGEwJURTELMAkGA1UE -CBMCQ0ExFDASBgNVBAcTC1NhbnRhIENsYXJhMREwDwYDVQQKEwhEYXRhU3RheDELMAkGA1UECxMC -VEUxGDAWBgNVBAMTD1BoaWxpcCBUaG9tcHNvbjAeFw0xNDEwMDMxNTQ2NDdaFw0xNTAxMDExNTQ2 -NDdaMGoxCzAJBgNVBAYTAlRFMQswCQYDVQQIEwJDQTEUMBIGA1UEBxMLU2FudGEgQ2xhcmExETAP -BgNVBAoTCERhdGFTdGF4MQswCQYDVQQLEwJURTEYMBYGA1UEAxMPUGhpbGlwIFRob21wc29uMIGf -MA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQChGDwrhpQR0d+NoqilMgsBlR6A2Dd1oMyIUe42sU4t -N63g5N44Ic4RpTiyaWgnAkP332ok3YAuVbxytwEv2K9HrUSiokAiuinlhhHA8CXTHt/1ItzzWj9u -J3Hneb+5lOkXVTZX7Y+q3aSdpx/HnZqn4i27DtLZF0z3LccWPWRinQIDAQABoyEwHzAdBgNVHQ4E -FgQU9WJpUhgGTBBH4xZBCV7Y9YISCp4wDQYJKoZIhvcNAQELBQADgYEAF6e8eVAjoZhfyJ+jW5mB -0pXa2vr5b7VFQ45voNncGrB3aNbz/AWT7LCJw88+Y5SJITgwN/8o3ZY6Y3MyiqeQYGo9WxDSWb5A -dZWFa03Z+hrVDQuw1r118zIhdS4KYDCQM2JfWY32TwK0MNG/6BO876HfkDpcjCYzq8Gh0gEguOA= ------END CERTIFICATE----- diff --git 
a/tests/integration/long/ssl/server_keystore.jks b/tests/integration/long/ssl/server_keystore.jks deleted file mode 100644 index 8125935516..0000000000 Binary files a/tests/integration/long/ssl/server_keystore.jks and /dev/null differ diff --git a/tests/integration/long/ssl/server_trust.jks b/tests/integration/long/ssl/server_trust.jks deleted file mode 100644 index feb0784a06..0000000000 Binary files a/tests/integration/long/ssl/server_trust.jks and /dev/null differ diff --git a/tests/integration/long/test_consistency.py b/tests/integration/long/test_consistency.py index 43734c7dc8..dfb30297f0 100644 --- a/tests/integration/long/test_consistency.py +++ b/tests/integration/long/test_consistency.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,31 +14,30 @@ # See the License for the specific language governing permissions and # limitations under the License. -import struct, time, traceback, sys, logging +import logging +import struct +import sys +import time +import traceback from cassandra import ConsistencyLevel, OperationTimedOut, ReadTimeout, WriteTimeout, Unavailable -from cassandra.cluster import Cluster +from cassandra.cluster import ExecutionProfile, EXEC_PROFILE_DEFAULT from cassandra.policies import TokenAwarePolicy, RoundRobinPolicy, DowngradingConsistencyRetryPolicy from cassandra.query import SimpleStatement -from tests.integration import use_singledc, PROTOCOL_VERSION, execute_until_pass - -from tests.integration.long.utils import (force_stop, create_schema, wait_for_down, wait_for_up, - start, CoordinatorStats) - -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +from tests.integration import use_singledc, execute_until_pass, TestCluster -ALL_CONSISTENCY_LEVELS = set([ - ConsistencyLevel.ANY, ConsistencyLevel.ONE, ConsistencyLevel.TWO, - ConsistencyLevel.QUORUM, ConsistencyLevel.THREE, - ConsistencyLevel.ALL, ConsistencyLevel.LOCAL_QUORUM, - ConsistencyLevel.EACH_QUORUM]) +from tests.integration.long.utils import ( + force_stop, create_schema, wait_for_down, wait_for_up, start, CoordinatorStats +) -MULTI_DC_CONSISTENCY_LEVELS = set([ - ConsistencyLevel.LOCAL_QUORUM, ConsistencyLevel.EACH_QUORUM]) +import unittest +ALL_CONSISTENCY_LEVELS = { + ConsistencyLevel.ANY, ConsistencyLevel.ONE, ConsistencyLevel.TWO, ConsistencyLevel.QUORUM, + ConsistencyLevel.THREE, ConsistencyLevel.ALL, ConsistencyLevel.LOCAL_QUORUM, + ConsistencyLevel.EACH_QUORUM +} +MULTI_DC_CONSISTENCY_LEVELS = {ConsistencyLevel.LOCAL_QUORUM, ConsistencyLevel.EACH_QUORUM} SINGLE_DC_CONSISTENCY_LEVELS = ALL_CONSISTENCY_LEVELS - MULTI_DC_CONSISTENCY_LEVELS log = logging.getLogger(__name__) @@ -82,7 +83,7 @@ def _query(self, session, keyspace, count, 
consistency_level=ConsistencyLevel.ON break except (OperationTimedOut, ReadTimeout): ex_type, ex, tb = sys.exc_info() - log.warn("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) del tb tries += 1 time.sleep(1) @@ -127,11 +128,11 @@ def _assert_reads_fail(self, session, keyspace, consistency_levels): pass def _test_tokenaware_one_node_down(self, keyspace, rf, accepted): - cluster = Cluster( - load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()), - protocol_version=PROTOCOL_VERSION) - session = cluster.connect() - wait_for_up(cluster, 1, wait=False) + cluster = TestCluster( + execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(TokenAwarePolicy(RoundRobinPolicy()))} + ) + session = cluster.connect(wait_for_all_pools=True) + wait_for_up(cluster, 1) wait_for_up(cluster, 2) create_schema(cluster, session, keyspace, replication_factor=rf) @@ -179,11 +180,11 @@ def test_rfthree_tokenaware_one_node_down(self): def test_rfthree_tokenaware_none_down(self): keyspace = 'test_rfthree_tokenaware_none_down' - cluster = Cluster( - load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()), - protocol_version=PROTOCOL_VERSION) - session = cluster.connect() - wait_for_up(cluster, 1, wait=False) + cluster = TestCluster( + execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(TokenAwarePolicy(RoundRobinPolicy()))} + ) + session = cluster.connect(wait_for_all_pools=True) + wait_for_up(cluster, 1) wait_for_up(cluster, 2) create_schema(cluster, session, keyspace, replication_factor=3) @@ -203,11 +204,11 @@ def test_rfthree_tokenaware_none_down(self): cluster.shutdown() def _test_downgrading_cl(self, keyspace, rf, accepted): - cluster = Cluster( - load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()), - default_retry_policy=DowngradingConsistencyRetryPolicy(), - protocol_version=PROTOCOL_VERSION) - session = cluster.connect() + cluster = TestCluster(execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile(TokenAwarePolicy(RoundRobinPolicy()), + DowngradingConsistencyRetryPolicy()) + }) + session = cluster.connect(wait_for_all_pools=True) create_schema(cluster, session, keyspace, replication_factor=rf) self._insert(session, keyspace, 1) @@ -247,22 +248,22 @@ def test_rftwo_downgradingcl(self): def test_rfthree_roundrobin_downgradingcl(self): keyspace = 'test_rfthree_roundrobin_downgradingcl' - cluster = Cluster( - load_balancing_policy=RoundRobinPolicy(), - default_retry_policy=DowngradingConsistencyRetryPolicy(), - protocol_version=PROTOCOL_VERSION) - self.rfthree_downgradingcl(cluster, keyspace, True) + with TestCluster(execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile(RoundRobinPolicy(), + DowngradingConsistencyRetryPolicy()) + }) as cluster: + self.rfthree_downgradingcl(cluster, keyspace, True) def test_rfthree_tokenaware_downgradingcl(self): keyspace = 'test_rfthree_tokenaware_downgradingcl' - cluster = Cluster( - load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()), - default_retry_policy=DowngradingConsistencyRetryPolicy(), - protocol_version=PROTOCOL_VERSION) - self.rfthree_downgradingcl(cluster, keyspace, False) + with TestCluster(execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile(TokenAwarePolicy(RoundRobinPolicy()), + DowngradingConsistencyRetryPolicy()) + }) as cluster: + self.rfthree_downgradingcl(cluster, keyspace, False) def rfthree_downgradingcl(self, cluster, keyspace, roundrobin): - session = cluster.connect() + 
session = cluster.connect(wait_for_all_pools=True) create_schema(cluster, session, keyspace, replication_factor=2) self._insert(session, keyspace, count=12) @@ -306,3 +307,61 @@ def rfthree_downgradingcl(self, cluster, keyspace, roundrobin): # instead we should create these elsewhere # def test_rfthree_downgradingcl_twodcs(self): # def test_rfthree_downgradingcl_twodcs_dcaware(self): + + +class ConnectivityTest(unittest.TestCase): + + def get_node_not_x(self, node_to_stop): + nodes = [1, 2, 3] + for num in nodes: + if num is not node_to_stop: + return num + + def test_pool_with_host_down(self): + """ + Test to ensure that cluster.connect() doesn't return prior to pools being initialized. + + This test will figure out which host our pool logic will connect to first. It then shuts that server down. + Previously the cluster.connect() would return prior to the pools being initialized, and the first queries would + return a no host exception + + @since 3.7.0 + @jira_ticket PYTHON-617 + @expected_result query should complete successfully + + @test_category connection + """ + + # find the first node, we will try create connections to, shut it down. + + # We will be shuting down a random house, so we need a complete contact list + all_contact_points = ["127.0.0.1", "127.0.0.2", "127.0.0.3"] + + # Connect up and find out which host will bet queries routed to to first + cluster = TestCluster() + cluster.connect(wait_for_all_pools=True) + hosts = cluster.metadata.all_hosts() + address = hosts[0].address + node_to_stop = int(address.split('.')[-1:][0]) + cluster.shutdown() + + # We now register a cluster that has it's Control Connection NOT on the node that we are shutting down. + # We do this so we don't miss the event + contact_point = '127.0.0.{0}'.format(self.get_node_not_x(node_to_stop)) + cluster = TestCluster(contact_points=[contact_point]) + cluster.connect(wait_for_all_pools=True) + try: + force_stop(node_to_stop) + wait_for_down(cluster, node_to_stop) + # Attempt a query against that node. It should complete + cluster2 = TestCluster(contact_points=all_contact_points) + session2 = cluster2.connect() + session2.execute("SELECT * FROM system.local") + finally: + cluster2.shutdown() + start(node_to_stop) + wait_for_up(cluster, node_to_stop) + cluster.shutdown() + + + diff --git a/tests/integration/long/test_failure_types.py b/tests/integration/long/test_failure_types.py index 7654fbcca7..c4751657e8 100644 --- a/tests/integration/long/test_failure_types.py +++ b/tests/integration/long/test_failure_types.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
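The consistency and connectivity tests above, like the failure-type tests that follow, replace direct Cluster(load_balancing_policy=..., default_retry_policy=...) construction with execution profiles and connect with wait_for_all_pools=True so the first queries cannot race pool initialization (PYTHON-617); TestCluster in these hunks is the suite's thin helper over Cluster. Roughly equivalent driver usage outside the test harness, with a placeholder contact point, would be::

    from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT
    from cassandra.policies import (
        DowngradingConsistencyRetryPolicy, RoundRobinPolicy, TokenAwarePolicy
    )

    profile = ExecutionProfile(
        load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()),
        retry_policy=DowngradingConsistencyRetryPolicy(),
    )

    cluster = Cluster(
        contact_points=['127.0.0.1'],  # placeholder contact point
        execution_profiles={EXEC_PROFILE_DEFAULT: profile},
    )

    # wait_for_all_pools=True blocks until a connection pool exists for every host,
    # so the first queries do not land on a host whose pool is still warming up.
    session = cluster.connect(wait_for_all_pools=True)
    session.execute("SELECT * FROM system.local")
    cluster.shutdown()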
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,24 +14,33 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys,logging, traceback, time - -from cassandra import ConsistencyLevel, OperationTimedOut, ReadTimeout, WriteTimeout, ReadFailure, WriteFailure,\ - FunctionFailure -from cassandra.cluster import Cluster +import logging +import sys +import traceback +import time +from packaging.version import Version +from unittest.mock import Mock + +from cassandra.policies import HostFilterPolicy, RoundRobinPolicy +from cassandra import ( + ConsistencyLevel, OperationTimedOut, ReadTimeout, WriteTimeout, ReadFailure, WriteFailure, + FunctionFailure, ProtocolVersion, +) +from cassandra.cluster import ExecutionProfile, EXEC_PROFILE_DEFAULT from cassandra.concurrent import execute_concurrent_with_args from cassandra.query import SimpleStatement -from tests.integration import use_singledc, PROTOCOL_VERSION, get_cluster, setup_keyspace, remove_cluster, get_node -from mock import Mock +from tests.integration import ( + use_singledc, PROTOCOL_VERSION, get_cluster, setup_keyspace, remove_cluster, + get_node, start_cluster_wait_for_up, requiresmallclockgranularity, + local, CASSANDRA_VERSION, TestCluster) + -try: - import unittest2 as unittest -except ImportError: - import unittest +import unittest log = logging.getLogger(__name__) +@local def setup_module(): """ We need some custom setup for this module. All unit tests in this module @@ -42,9 +53,12 @@ def setup_module(): use_singledc(start=False) ccm_cluster = get_cluster() ccm_cluster.stop() - config_options = {'tombstone_failure_threshold': 2000, 'tombstone_warn_threshold': 1000} + config_options = { + 'tombstone_failure_threshold': 2000, + 'tombstone_warn_threshold': 1000, + } ccm_cluster.set_configuration_options(config_options) - ccm_cluster.start(wait_for_binary_proto=True, wait_other_notice=True) + start_cluster_wait_for_up(ccm_cluster) setup_keyspace() @@ -63,13 +77,11 @@ def setUp(self): """ Test is skipped if run with native protocol version <4 """ - if PROTOCOL_VERSION < 4: raise unittest.SkipTest( "Native protocol 4,0+ is required for custom payloads, currently using %r" % (PROTOCOL_VERSION,)) - - self.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + self.cluster = TestCluster() self.session = self.cluster.connect() self.nodes_currently_failing = [] self.node1, self.node2, self.node3 = get_cluster().nodes.values() @@ -89,7 +101,7 @@ def execute_helper(self, session, query): return session.execute(query) except OperationTimedOut: ex_type, ex, tb = sys.exc_info() - log.warn("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) del tb tries += 1 @@ -102,7 +114,7 @@ def execute_concurrent_args_helper(self, session, query, params): return execute_concurrent_with_args(session, query, params, concurrency=50) except (ReadTimeout, WriteTimeout, OperationTimedOut, ReadFailure, WriteFailure): ex_type, ex, tb = sys.exc_info() - log.warn("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) del tb tries += 1 @@ 
-128,25 +140,32 @@ def setFailingNodes(self, failing_nodes, keyspace): # Ensure all nodes not on the list, but that are currently set to failing are enabled for node in self.nodes_currently_failing: if node not in failing_nodes: - node.stop(wait_other_notice=True, gently=False) + node.stop(wait_other_notice=True, gently=True) node.start(wait_for_binary_proto=True, wait_other_notice=True) self.nodes_currently_failing.remove(node) - def _perform_cql_statement(self, text, consistency_level, expected_exception): + def _perform_cql_statement(self, text, consistency_level, expected_exception, session=None): """ Simple helper method to preform cql statements and check for expected exception @param text CQl statement to execute @param consistency_level Consistency level at which it is to be executed @param expected_exception Exception expected to be throw or none """ + if session is None: + session = self.session statement = SimpleStatement(text) statement.consistency_level = consistency_level if expected_exception is None: - self.execute_helper(self.session, statement) + self.execute_helper(session, statement) else: - with self.assertRaises(expected_exception): - self.execute_helper(self.session, statement) + with self.assertRaises(expected_exception) as cm: + self.execute_helper(session, statement) + if ProtocolVersion.uses_error_code_map(PROTOCOL_VERSION): + if isinstance(cm.exception, ReadFailure): + self.assertEqual(list(cm.exception.error_code_map.values())[0], 1) + if isinstance(cm.exception, WriteFailure): + self.assertEqual(list(cm.exception.error_code_map.values())[0], 0) def test_write_failures_from_coordinator(self): """ @@ -157,8 +176,8 @@ def test_write_failures_from_coordinator(self): factor of the keyspace, and the consistency level, we will expect the coordinator to send WriteFailure, or not. - @since 2.6.0 - @jira_ticket PYTHON-238 + @since 2.6.0, 3.7.0 + @jira_ticket PYTHON-238, PYTHON-619 @expected_result Appropriate write failures from the coordinator @test_category queries:basic @@ -217,8 +236,8 @@ def test_tombstone_overflow_read_failure(self): from the coordinator. 
- @since 2.6.0 - @jira_ticket PYTHON-238 + @since 2.6.0, 3.7.0 + @jira_ticket PYTHON-238, PYTHON-619 @expected_result Appropriate write failures from the coordinator @test_category queries:basic @@ -237,7 +256,8 @@ def test_tombstone_overflow_read_failure(self): parameters = [(x,) for x in range(3000)] self.execute_concurrent_args_helper(self.session, statement, parameters) - statement = self.session.prepare("DELETE v1 FROM test3rf.test2 WHERE k = 1 AND v0 =?") + column = 'v1' if CASSANDRA_VERSION < Version('4.0') else '' + statement = self.session.prepare("DELETE {} FROM test3rf.test2 WHERE k = 1 AND v0 =?".format(column)) parameters = [(x,) for x in range(2001)] self.execute_concurrent_args_helper(self.session, statement, parameters) @@ -304,31 +324,39 @@ def test_user_function_failure(self): """, consistency_level=ConsistencyLevel.ALL, expected_exception=None) +@requiresmallclockgranularity class TimeoutTimerTest(unittest.TestCase): def setUp(self): """ Setup sessions and pause node1 """ - self.cluster = Cluster(protocol_version=PROTOCOL_VERSION) - self.session = self.cluster.connect() + self.cluster = TestCluster( + execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile( + load_balancing_policy=HostFilterPolicy( + RoundRobinPolicy(), lambda host: host.address == "127.0.0.1" + ) + ) + } + ) - # self.node1, self.node2, self.node3 = get_cluster().nodes.values() - self.node1 = get_node(1) - self.cluster = Cluster(protocol_version=PROTOCOL_VERSION) - self.session = self.cluster.connect() + self.session = self.cluster.connect(wait_for_all_pools=True) + + self.control_connection_host_number = 1 + self.node_to_stop = get_node(self.control_connection_host_number) ddl = ''' CREATE TABLE test3rf.timeout ( k int PRIMARY KEY, v int )''' self.session.execute(ddl) - self.node1.pause() + self.node_to_stop.pause() def tearDown(self): """ Shutdown cluster and resume node1 """ - self.node1.resume() + self.node_to_stop.resume() self.session.execute("DROP TABLE test3rf.timeout") self.cluster.shutdown() @@ -358,13 +386,14 @@ def test_async_timeouts(self): future.result() end_time = time.time() total_time = end_time-start_time - expected_time = self.session.default_timeout + expected_time = self.cluster.profile_manager.default.request_timeout # check timeout and ensure it's within a reasonable range self.assertAlmostEqual(expected_time, total_time, delta=.05) # Test with user defined timeout (Should be 1) + expected_time = 1 start_time = time.time() - future = self.session.execute_async(ss, timeout=1) + future = self.session.execute_async(ss, timeout=expected_time) mock_callback = Mock(return_value=None) mock_errorback = Mock(return_value=None) future.add_callback(mock_callback) @@ -374,16 +403,7 @@ def test_async_timeouts(self): future.result() end_time = time.time() total_time = end_time-start_time - expected_time = 1 # check timeout and ensure it's within a reasonable range self.assertAlmostEqual(expected_time, total_time, delta=.05) self.assertTrue(mock_errorback.called) self.assertFalse(mock_callback.called) - - - - - - - - diff --git a/tests/integration/long/test_ipv6.py b/tests/integration/long/test_ipv6.py index 618050fe42..e20f11cc9c 100644 --- a/tests/integration/long/test_ipv6.py +++ b/tests/integration/long/test_ipv6.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,14 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os, socket +import os, socket, errno from ccmlib import common -from cassandra.cluster import Cluster, NoHostAvailable +from cassandra.cluster import NoHostAvailable from cassandra.io.asyncorereactor import AsyncoreConnection from tests import is_monkey_patched -from tests.integration import use_cluster, remove_cluster, PROTOCOL_VERSION +from tests.integration import use_cluster, remove_cluster, TestCluster if is_monkey_patched(): LibevConnection = -1 @@ -30,10 +32,7 @@ except ImportError: LibevConnection = None -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest # If more modules do IPV6 testing, this can be moved down to integration.__init__. @@ -65,7 +64,7 @@ def validate_host_viable(): # this is something ccm does when starting, but preemptively check to avoid # spinning up the cluster if it's not going to work try: - common.check_socket_available(('::1', 9042)) + common.assert_socket_available(('::1', 9042)) except: raise unittest.SkipTest('failed binding ipv6 loopback ::1 on 9042') @@ -75,8 +74,7 @@ class IPV6ConnectionTest(object): connection_class = None def test_connect(self): - cluster = Cluster(connection_class=self.connection_class, contact_points=['::1'], connect_timeout=10, - protocol_version=PROTOCOL_VERSION) + cluster = TestCluster(connection_class=self.connection_class, contact_points=['::1'], connect_timeout=10) session = cluster.connect() future = session.execute_async("SELECT * FROM system.local") future.result() @@ -84,17 +82,17 @@ def test_connect(self): cluster.shutdown() def test_error(self): - cluster = Cluster(connection_class=self.connection_class, contact_points=['::1'], port=9043, - connect_timeout=10, protocol_version=PROTOCOL_VERSION) - self.assertRaisesRegexp(NoHostAvailable, '\(\'Unable to connect.*%s.*::1\', 9043.*Connection refused.*' - % os.errno.ECONNREFUSED, cluster.connect) + cluster = TestCluster(connection_class=self.connection_class, contact_points=['::1'], port=9043, + connect_timeout=10) + self.assertRaisesRegex(NoHostAvailable, '\(\'Unable to connect.*%s.*::1\', 9043.*Connection refused.*' + % errno.ECONNREFUSED, cluster.connect) def test_error_multiple(self): if len(socket.getaddrinfo('localhost', 9043, socket.AF_UNSPEC, socket.SOCK_STREAM)) < 2: raise unittest.SkipTest('localhost only resolves one address') - cluster = Cluster(connection_class=self.connection_class, contact_points=['localhost'], port=9043, - connect_timeout=10, protocol_version=PROTOCOL_VERSION) - self.assertRaisesRegexp(NoHostAvailable, '\(\'Unable to connect.*Tried connecting to \[\(.*\(.*\].*Last error', + cluster = TestCluster(connection_class=self.connection_class, contact_points=['localhost'], port=9043, + connect_timeout=10) + self.assertRaisesRegex(NoHostAvailable, '\(\'Unable to 
connect.*Tried connecting to \[\(.*\(.*\].*Last error', cluster.connect) diff --git a/tests/integration/long/test_large_data.py b/tests/integration/long/test_large_data.py index acac09893d..8ff482271e 100644 --- a/tests/integration/long/test_large_data.py +++ b/tests/integration/long/test_large_data.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -21,16 +23,13 @@ import logging, sys, traceback, time from cassandra import ConsistencyLevel, OperationTimedOut, WriteTimeout -from cassandra.cluster import Cluster +from cassandra.cluster import ExecutionProfile, EXEC_PROFILE_DEFAULT from cassandra.query import dict_factory from cassandra.query import SimpleStatement -from tests.integration import use_singledc, PROTOCOL_VERSION +from tests.integration import use_singledc, PROTOCOL_VERSION, TestCluster from tests.integration.long.utils import create_schema -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest log = logging.getLogger(__name__) @@ -61,11 +60,10 @@ def setUp(self): self.keyspace = 'large_data' def make_session_and_keyspace(self): - cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cluster = TestCluster(execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile(request_timeout=20, row_factory=dict_factory) + }) session = cluster.connect() - session.default_timeout = 20.0 # increase the default timeout - session.row_factory = dict_factory - create_schema(cluster, session, self.keyspace) return session @@ -82,7 +80,7 @@ def batch_futures(self, session, statement_generator): except (OperationTimedOut, WriteTimeout): ex_type, ex, tb = sys.exc_info() number_of_timeouts += 1 - log.warn("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) del tb time.sleep(1) except Empty: @@ -97,7 +95,7 @@ def batch_futures(self, session, statement_generator): except (OperationTimedOut, WriteTimeout): ex_type, ex, tb = sys.exc_info() number_of_timeouts += 1 - log.warn("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) del tb time.sleep(1) except Empty: @@ -119,7 +117,7 @@ def test_wide_rows(self): # Verify for i, row in enumerate(results): - self.assertEqual(row['i'], i) + self.assertAlmostEqual(row['i'], i, delta=3) session.cluster.shutdown() @@ -156,8 +154,8 @@ def test_wide_batch_rows(self): #If we timeout on insertion that's bad but it could be just slow underlying c* #Attempt to validate anyway, we will fail if we don't get the right data back. 
ex_type, ex, tb = sys.exc_info() - log.warn("Batch wide row insertion timed out, this may require additional investigation") - log.warn("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + log.warning("Batch wide row insertion timed out, this may require additional investigation") + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) del tb # Verify diff --git a/tests/integration/long/test_loadbalancingpolicies.py b/tests/integration/long/test_loadbalancingpolicies.py index f05253175b..7cb173643c 100644 --- a/tests/integration/long/test_loadbalancingpolicies.py +++ b/tests/integration/long/test_loadbalancingpolicies.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,27 +14,31 @@ # See the License for the specific language governing permissions and # limitations under the License. -import struct, time, logging, sys, traceback +import logging +import struct +import sys +import traceback +from cassandra import cqltypes from cassandra import ConsistencyLevel, Unavailable, OperationTimedOut, ReadTimeout, ReadFailure, \ WriteTimeout, WriteFailure -from cassandra.cluster import Cluster, NoHostAvailable, Session +from cassandra.cluster import NoHostAvailable, ExecutionProfile, EXEC_PROFILE_DEFAULT from cassandra.concurrent import execute_concurrent_with_args from cassandra.metadata import murmur3 -from cassandra.policies import (RoundRobinPolicy, DCAwareRoundRobinPolicy, - TokenAwarePolicy, WhiteListRoundRobinPolicy) +from cassandra.policies import ( + RoundRobinPolicy, DCAwareRoundRobinPolicy, + TokenAwarePolicy, WhiteListRoundRobinPolicy, + HostFilterPolicy +) from cassandra.query import SimpleStatement -from tests.integration import use_singledc, use_multidc, remove_cluster, PROTOCOL_VERSION +from tests.integration import use_singledc, use_multidc, remove_cluster, TestCluster, greaterthanorequalcass40, notdse from tests.integration.long.utils import (wait_for_up, create_schema, CoordinatorStats, force_stop, wait_for_down, decommission, start, bootstrap, stop, IP_FORMAT) -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest log = logging.getLogger(__name__) @@ -40,16 +46,65 @@ class LoadBalancingPolicyTests(unittest.TestCase): def setUp(self): - remove_cluster() # clear ahead of test so it doesn't use one left in unknown state + remove_cluster() # clear ahead of test so it doesn't use one left in unknown state self.coordinator_stats = CoordinatorStats() self.prepared = None + self.probe_cluster = None + + def tearDown(self): + if self.probe_cluster: + self.probe_cluster.shutdown() @classmethod def teardown_class(cls): remove_cluster() + def 
_connect_probe_cluster(self): + if not self.probe_cluster: + # distinct cluster so we can see the status of nodes ignored by the LBP being tested + self.probe_cluster = TestCluster( + schema_metadata_enabled=False, + token_metadata_enabled=False, + execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(load_balancing_policy=RoundRobinPolicy())} + ) + self.probe_session = self.probe_cluster.connect() + + def _wait_for_nodes_up(self, nodes, cluster=None): + log.debug('entered: _wait_for_nodes_up(nodes={ns}, ' + 'cluster={cs})'.format(ns=nodes, + cs=cluster)) + if not cluster: + log.debug('connecting to cluster') + self._connect_probe_cluster() + cluster = self.probe_cluster + for n in nodes: + wait_for_up(cluster, n) + + def _wait_for_nodes_down(self, nodes, cluster=None): + log.debug('entered: _wait_for_nodes_down(nodes={ns}, ' + 'cluster={cs})'.format(ns=nodes, + cs=cluster)) + if not cluster: + self._connect_probe_cluster() + cluster = self.probe_cluster + for n in nodes: + wait_for_down(cluster, n) + + def _cluster_session_with_lbp(self, lbp): + # create a cluster with no delay on events + + cluster = TestCluster(topology_event_refresh_window=0, status_event_refresh_window=0, + execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(load_balancing_policy=lbp)}) + session = cluster.connect() + return cluster, session + def _insert(self, session, keyspace, count=12, consistency_level=ConsistencyLevel.ONE): + log.debug('entered _insert(' + 'session={session}, keyspace={keyspace}, ' + 'count={count}, consistency_level={consistency_level}' + ')'.format(session=session, keyspace=keyspace, count=count, + consistency_level=consistency_level)) session.execute('USE %s' % keyspace) ss = SimpleStatement('INSERT INTO cf(k, i) VALUES (0, 0)', consistency_level=consistency_level) @@ -57,10 +112,11 @@ def _insert(self, session, keyspace, count=12, while tries < 100: try: execute_concurrent_with_args(session, ss, [None] * count) + log.debug('Completed _insert on try #{}'.format(tries + 1)) return except (OperationTimedOut, WriteTimeout, WriteFailure): ex_type, ex, tb = sys.exc_info() - log.warn("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) del tb tries += 1 @@ -68,11 +124,18 @@ def _insert(self, session, keyspace, count=12, def _query(self, session, keyspace, count=12, consistency_level=ConsistencyLevel.ONE, use_prepared=False): + log.debug('entered _query(' + 'session={session}, keyspace={keyspace}, ' + 'count={count}, consistency_level={consistency_level}, ' + 'use_prepared={use_prepared}' + ')'.format(session=session, keyspace=keyspace, count=count, + consistency_level=consistency_level, + use_prepared=use_prepared)) if use_prepared: query_string = 'SELECT * FROM %s.cf WHERE k = ?' 
% keyspace if not self.prepared or self.prepared.query_string != query_string: self.prepared = session.prepare(query_string) - self.prepared.consistency_level=consistency_level + self.prepared.consistency_level = consistency_level for i in range(count): tries = 0 while True: @@ -83,7 +146,7 @@ def _query(self, session, keyspace, count=12, break except (OperationTimedOut, ReadTimeout, ReadFailure): ex_type, ex, tb = sys.exc_info() - log.warn("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) del tb tries += 1 else: @@ -101,13 +164,13 @@ def _query(self, session, keyspace, count=12, break except (OperationTimedOut, ReadTimeout, ReadFailure): ex_type, ex, tb = sys.exc_info() - log.warn("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) del tb tries += 1 def test_token_aware_is_used_by_default(self): """ - Test for default loadbalacing policy + Test for default load balancing policy test_token_aware_is_used_by_default tests that the default loadbalancing policy is policies.TokenAwarePolicy. It creates a simple Cluster and verifies that the default loadbalancing policy is TokenAwarePolicy if the @@ -120,26 +183,21 @@ def test_token_aware_is_used_by_default(self): @test_category load_balancing:token_aware """ - cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cluster = TestCluster() + self.addCleanup(cluster.shutdown) if murmur3 is not None: - self.assertTrue(isinstance(cluster.load_balancing_policy, TokenAwarePolicy)) + self.assertTrue(isinstance(cluster.profile_manager.default.load_balancing_policy, TokenAwarePolicy)) else: - self.assertTrue(isinstance(cluster.load_balancing_policy, DCAwareRoundRobinPolicy)) - - cluster.shutdown() + self.assertTrue(isinstance(cluster.profile_manager.default.load_balancing_policy, DCAwareRoundRobinPolicy)) def test_roundrobin(self): use_singledc() keyspace = 'test_roundrobin' - cluster = Cluster( - load_balancing_policy=RoundRobinPolicy(), - protocol_version=PROTOCOL_VERSION) - session = cluster.connect() - wait_for_up(cluster, 1, wait=False) - wait_for_up(cluster, 2, wait=False) - wait_for_up(cluster, 3) + cluster, session = self._cluster_session_with_lbp(RoundRobinPolicy()) + self.addCleanup(cluster.shutdown) + self._wait_for_nodes_up(range(1, 4), cluster) create_schema(cluster, session, keyspace, replication_factor=3) self._insert(session, keyspace) self._query(session, keyspace) @@ -149,7 +207,7 @@ def test_roundrobin(self): self.coordinator_stats.assert_query_count_equals(self, 3, 4) force_stop(3) - wait_for_down(cluster, 3) + self._wait_for_nodes_down([3], cluster) self.coordinator_stats.reset_counts() self._query(session, keyspace) @@ -160,8 +218,8 @@ def test_roundrobin(self): decommission(1) start(3) - wait_for_down(cluster, 1) - wait_for_up(cluster, 3) + self._wait_for_nodes_down([1], cluster) + self._wait_for_nodes_up([3], cluster) self.coordinator_stats.reset_counts() self._query(session, keyspace) @@ -173,14 +231,9 @@ def test_roundrobin(self): def test_roundrobin_two_dcs(self): use_multidc([2, 2]) keyspace = 'test_roundrobin_two_dcs' - cluster = Cluster( - load_balancing_policy=RoundRobinPolicy(), - protocol_version=PROTOCOL_VERSION) - session = cluster.connect() - wait_for_up(cluster, 1, wait=False) - wait_for_up(cluster, 2, wait=False) - wait_for_up(cluster, 3, wait=False) - 
wait_for_up(cluster, 4) + cluster, session = self._cluster_session_with_lbp(RoundRobinPolicy()) + self.addCleanup(cluster.shutdown) + self._wait_for_nodes_up(range(1, 5), cluster) create_schema(cluster, session, keyspace, replication_strategy=[2, 2]) self._insert(session, keyspace) @@ -197,7 +250,7 @@ def test_roundrobin_two_dcs(self): # reset control connection self._insert(session, keyspace, count=1000) - wait_for_up(cluster, 5) + self._wait_for_nodes_up([5], cluster) self.coordinator_stats.reset_counts() self._query(session, keyspace) @@ -208,19 +261,12 @@ def test_roundrobin_two_dcs(self): self.coordinator_stats.assert_query_count_equals(self, 4, 3) self.coordinator_stats.assert_query_count_equals(self, 5, 3) - cluster.shutdown() - def test_roundrobin_two_dcs_2(self): use_multidc([2, 2]) keyspace = 'test_roundrobin_two_dcs_2' - cluster = Cluster( - load_balancing_policy=RoundRobinPolicy(), - protocol_version=PROTOCOL_VERSION) - session = cluster.connect() - wait_for_up(cluster, 1, wait=False) - wait_for_up(cluster, 2, wait=False) - wait_for_up(cluster, 3, wait=False) - wait_for_up(cluster, 4) + cluster, session = self._cluster_session_with_lbp(RoundRobinPolicy()) + self.addCleanup(cluster.shutdown) + self._wait_for_nodes_up(range(1, 5), cluster) create_schema(cluster, session, keyspace, replication_strategy=[2, 2]) self._insert(session, keyspace) @@ -237,7 +283,7 @@ def test_roundrobin_two_dcs_2(self): # reset control connection self._insert(session, keyspace, count=1000) - wait_for_up(cluster, 5) + self._wait_for_nodes_up([5], cluster) self.coordinator_stats.reset_counts() self._query(session, keyspace) @@ -248,20 +294,12 @@ def test_roundrobin_two_dcs_2(self): self.coordinator_stats.assert_query_count_equals(self, 4, 3) self.coordinator_stats.assert_query_count_equals(self, 5, 3) - cluster.shutdown() - def test_dc_aware_roundrobin_two_dcs(self): use_multidc([3, 2]) keyspace = 'test_dc_aware_roundrobin_two_dcs' - cluster = Cluster( - load_balancing_policy=DCAwareRoundRobinPolicy('dc1'), - protocol_version=PROTOCOL_VERSION) - session = cluster.connect() - wait_for_up(cluster, 1, wait=False) - wait_for_up(cluster, 2, wait=False) - wait_for_up(cluster, 3, wait=False) - wait_for_up(cluster, 4, wait=False) - wait_for_up(cluster, 5) + cluster, session = self._cluster_session_with_lbp(DCAwareRoundRobinPolicy('dc1')) + self.addCleanup(cluster.shutdown) + self._wait_for_nodes_up(range(1, 6)) create_schema(cluster, session, keyspace, replication_strategy=[2, 2]) self._insert(session, keyspace) @@ -273,20 +311,12 @@ def test_dc_aware_roundrobin_two_dcs(self): self.coordinator_stats.assert_query_count_equals(self, 4, 0) self.coordinator_stats.assert_query_count_equals(self, 5, 0) - cluster.shutdown() - def test_dc_aware_roundrobin_two_dcs_2(self): use_multidc([3, 2]) keyspace = 'test_dc_aware_roundrobin_two_dcs_2' - cluster = Cluster( - load_balancing_policy=DCAwareRoundRobinPolicy('dc2'), - protocol_version=PROTOCOL_VERSION) - session = cluster.connect() - wait_for_up(cluster, 1, wait=False) - wait_for_up(cluster, 2, wait=False) - wait_for_up(cluster, 3, wait=False) - wait_for_up(cluster, 4, wait=False) - wait_for_up(cluster, 5) + cluster, session = self._cluster_session_with_lbp(DCAwareRoundRobinPolicy('dc2')) + self.addCleanup(cluster.shutdown) + self._wait_for_nodes_up(range(1, 6)) create_schema(cluster, session, keyspace, replication_strategy=[2, 2]) self._insert(session, keyspace) @@ -298,19 +328,12 @@ def test_dc_aware_roundrobin_two_dcs_2(self): 
self.coordinator_stats.assert_query_count_equals(self, 4, 6) self.coordinator_stats.assert_query_count_equals(self, 5, 6) - cluster.shutdown() - def test_dc_aware_roundrobin_one_remote_host(self): use_multidc([2, 2]) keyspace = 'test_dc_aware_roundrobin_one_remote_host' - cluster = Cluster( - load_balancing_policy=DCAwareRoundRobinPolicy('dc2', used_hosts_per_remote_dc=1), - protocol_version=PROTOCOL_VERSION) - session = cluster.connect() - wait_for_up(cluster, 1, wait=False) - wait_for_up(cluster, 2, wait=False) - wait_for_up(cluster, 3, wait=False) - wait_for_up(cluster, 4) + cluster, session = self._cluster_session_with_lbp(DCAwareRoundRobinPolicy('dc2', used_hosts_per_remote_dc=1)) + self.addCleanup(cluster.shutdown) + self._wait_for_nodes_up(range(1, 5)) create_schema(cluster, session, keyspace, replication_strategy=[2, 2]) self._insert(session, keyspace) @@ -323,7 +346,7 @@ def test_dc_aware_roundrobin_one_remote_host(self): self.coordinator_stats.reset_counts() bootstrap(5, 'dc1') - wait_for_up(cluster, 5) + self._wait_for_nodes_up([5]) self._query(session, keyspace) @@ -336,8 +359,7 @@ def test_dc_aware_roundrobin_one_remote_host(self): self.coordinator_stats.reset_counts() decommission(3) decommission(4) - wait_for_down(cluster, 3, wait=True) - wait_for_down(cluster, 4, wait=True) + self._wait_for_nodes_down([3, 4]) self._query(session, keyspace) @@ -350,7 +372,7 @@ def test_dc_aware_roundrobin_one_remote_host(self): self.coordinator_stats.reset_counts() decommission(5) - wait_for_down(cluster, 5, wait=True) + self._wait_for_nodes_down([5]) self._query(session, keyspace) @@ -364,7 +386,7 @@ def test_dc_aware_roundrobin_one_remote_host(self): self.coordinator_stats.reset_counts() decommission(1) - wait_for_down(cluster, 1, wait=True) + self._wait_for_nodes_down([1]) self._query(session, keyspace) @@ -383,8 +405,6 @@ def test_dc_aware_roundrobin_one_remote_host(self): except NoHostAvailable: pass - cluster.shutdown() - def test_token_aware(self): keyspace = 'test_token_aware' self.token_aware(keyspace) @@ -395,13 +415,9 @@ def test_token_aware_prepared(self): def token_aware(self, keyspace, use_prepared=False): use_singledc() - cluster = Cluster( - load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()), - protocol_version=PROTOCOL_VERSION) - session = cluster.connect() - wait_for_up(cluster, 1, wait=False) - wait_for_up(cluster, 2, wait=False) - wait_for_up(cluster, 3) + cluster, session = self._cluster_session_with_lbp(TokenAwarePolicy(RoundRobinPolicy())) + self.addCleanup(cluster.shutdown) + self._wait_for_nodes_up(range(1, 4), cluster) create_schema(cluster, session, keyspace, replication_factor=1) self._insert(session, keyspace) @@ -420,7 +436,7 @@ def token_aware(self, keyspace, use_prepared=False): self.coordinator_stats.reset_counts() force_stop(2) - wait_for_down(cluster, 2, wait=True) + self._wait_for_nodes_down([2], cluster) try: self._query(session, keyspace, use_prepared=use_prepared) @@ -432,7 +448,7 @@ def token_aware(self, keyspace, use_prepared=False): self.coordinator_stats.reset_counts() start(2) - wait_for_up(cluster, 2, wait=True) + self._wait_for_nodes_up([2], cluster) self._query(session, keyspace, use_prepared=use_prepared) @@ -442,7 +458,7 @@ def token_aware(self, keyspace, use_prepared=False): self.coordinator_stats.reset_counts() stop(2) - wait_for_down(cluster, 2, wait=True) + self._wait_for_nodes_down([2], cluster) try: self._query(session, keyspace, use_prepared=use_prepared) @@ -452,9 +468,9 @@ def token_aware(self, keyspace, use_prepared=False): 
self.coordinator_stats.reset_counts() start(2) - wait_for_up(cluster, 2, wait=True) + self._wait_for_nodes_up([2], cluster) decommission(2) - wait_for_down(cluster, 2, wait=True) + self._wait_for_nodes_down([2], cluster) self._query(session, keyspace, use_prepared=use_prepared) @@ -465,19 +481,13 @@ def token_aware(self, keyspace, use_prepared=False): self.assertEqual(results, set([0, 12])) self.coordinator_stats.assert_query_count_equals(self, 2, 0) - cluster.shutdown() - def test_token_aware_composite_key(self): use_singledc() keyspace = 'test_token_aware_composite_key' table = 'composite' - cluster = Cluster( - load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()), - protocol_version=PROTOCOL_VERSION) - session = cluster.connect() - wait_for_up(cluster, 1, wait=False) - wait_for_up(cluster, 2, wait=False) - wait_for_up(cluster, 3) + cluster, session = self._cluster_session_with_lbp(TokenAwarePolicy(RoundRobinPolicy())) + self.addCleanup(cluster.shutdown) + self._wait_for_nodes_up(range(1, 4), cluster) create_schema(cluster, session, keyspace, replication_factor=2) session.execute('CREATE TABLE %s (' @@ -490,23 +500,27 @@ def test_token_aware_composite_key(self): '(k1, k2, i) ' 'VALUES ' '(?, ?, ?)' % table) - session.execute(prepared.bind((1, 2, 3))) + bound = prepared.bind((1, 2, 3)) + result = session.execute(bound) + self.assertIn(result.response_future.attempted_hosts[0], + cluster.metadata.get_replicas(keyspace, bound.routing_key)) + + # There could be race condition with querying a node + # which doesn't yet have the data so we query one of + # the replicas + results = session.execute(SimpleStatement('SELECT * FROM %s WHERE k1 = 1 AND k2 = 2' % table, + routing_key=bound.routing_key)) + self.assertIn(results.response_future.attempted_hosts[0], + cluster.metadata.get_replicas(keyspace, bound.routing_key)) - results = session.execute('SELECT * FROM %s WHERE k1 = 1 AND k2 = 2' % table) self.assertTrue(results[0].i) - cluster.shutdown() - def test_token_aware_with_rf_2(self, use_prepared=False): use_singledc() keyspace = 'test_token_aware_with_rf_2' - cluster = Cluster( - load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()), - protocol_version=PROTOCOL_VERSION) - session = cluster.connect() - wait_for_up(cluster, 1, wait=False) - wait_for_up(cluster, 2, wait=False) - wait_for_up(cluster, 3) + cluster, session = self._cluster_session_with_lbp(TokenAwarePolicy(RoundRobinPolicy())) + self.addCleanup(cluster.shutdown) + self._wait_for_nodes_up(range(1, 4), cluster) create_schema(cluster, session, keyspace, replication_factor=2) self._insert(session, keyspace) @@ -518,7 +532,7 @@ def test_token_aware_with_rf_2(self, use_prepared=False): self.coordinator_stats.reset_counts() stop(2) - wait_for_down(cluster, 2, wait=True) + self._wait_for_nodes_down([2], cluster) self._query(session, keyspace) @@ -526,33 +540,170 @@ def test_token_aware_with_rf_2(self, use_prepared=False): self.coordinator_stats.assert_query_count_equals(self, 2, 0) self.coordinator_stats.assert_query_count_equals(self, 3, 12) - cluster.shutdown() - def test_token_aware_with_local_table(self): use_singledc() - cluster = Cluster( - load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()), - protocol_version=PROTOCOL_VERSION) - session = cluster.connect() + cluster, session = self._cluster_session_with_lbp(TokenAwarePolicy(RoundRobinPolicy())) + self.addCleanup(cluster.shutdown) + self._wait_for_nodes_up(range(1, 4), cluster) p = session.prepare("SELECT * FROM system.local WHERE key=?") # this would blow up prior 
to 61b4fad r = session.execute(p, ('local',)) self.assertEqual(r[0].key, 'local') - cluster.shutdown() + def test_token_aware_with_shuffle_rf2(self): + """ + Test to validate that the hosts are shuffled when the `shuffle_replicas` is truthy + @since 3.8 + @jira_ticket PYTHON-676 + @expected_result the requests are spread across the replicas, + when one of them is down, the requests target the available one + + @test_category policy + """ + keyspace = 'test_token_aware_with_rf_2' + cluster, session = self._set_up_shuffle_test(keyspace, replication_factor=2) + self.addCleanup(cluster.shutdown) + + self._check_query_order_changes(session=session, keyspace=keyspace) + + # check TokenAwarePolicy still returns the remaining replicas when one goes down + self.coordinator_stats.reset_counts() + stop(2) + self._wait_for_nodes_down([2], cluster) + + self._query(session, keyspace) + + self.coordinator_stats.assert_query_count_equals(self, 1, 0) + self.coordinator_stats.assert_query_count_equals(self, 2, 0) + self.coordinator_stats.assert_query_count_equals(self, 3, 12) + + def test_token_aware_with_shuffle_rf3(self): + """ + Test to validate that the hosts are shuffled when the `shuffle_replicas` is truthy + @since 3.8 + @jira_ticket PYTHON-676 + @expected_result the requests are spread across the replicas, + when one of them is down, the requests target the other available ones + + @test_category policy + """ + keyspace = 'test_token_aware_with_rf_3' + cluster, session = self._set_up_shuffle_test(keyspace, replication_factor=3) + self.addCleanup(cluster.shutdown) + + self._check_query_order_changes(session=session, keyspace=keyspace) + + # check TokenAwarePolicy still returns the remaining replicas when one goes down + self.coordinator_stats.reset_counts() + stop(1) + self._wait_for_nodes_down([1], cluster) + + self._query(session, keyspace) + + self.coordinator_stats.assert_query_count_equals(self, 1, 0) + query_count_two = self.coordinator_stats.get_query_count(2) + query_count_three = self.coordinator_stats.get_query_count(3) + self.assertEqual(query_count_two + query_count_three, 12) + + self.coordinator_stats.reset_counts() + stop(2) + self._wait_for_nodes_down([2], cluster) + + self._query(session, keyspace) + + self.coordinator_stats.assert_query_count_equals(self, 1, 0) + self.coordinator_stats.assert_query_count_equals(self, 2, 0) + self.coordinator_stats.assert_query_count_equals(self, 3, 12) + + @notdse + @greaterthanorequalcass40 + def test_token_aware_with_transient_replication(self): + """ + Test to validate that the token aware policy doesn't route any request to a transient node. + + @since 3.23 + @jira_ticket PYTHON-1207 + @expected_result the requests are spread across the 2 full replicas and + no other nodes are queried by the coordinator.
+ + @test_category policy + """ + # We can test this with a single dc when CASSANDRA-15670 is fixed + use_multidc([3, 3]) + + cluster, session = self._cluster_session_with_lbp( + TokenAwarePolicy(DCAwareRoundRobinPolicy(), shuffle_replicas=True) + ) + self.addCleanup(cluster.shutdown) + + session.execute("CREATE KEYSPACE test_tr WITH replication = {'class': 'NetworkTopologyStrategy', 'dc1': '3/1', 'dc2': '3/1'};") + session.execute("CREATE TABLE test_tr.users (id int PRIMARY KEY, username text) WITH read_repair ='NONE';") + for i in range(100): + session.execute("INSERT INTO test_tr.users (id, username) VALUES (%d, 'user');" % (i,)) + + query = session.prepare("SELECT * FROM test_tr.users WHERE id = ?") + for i in range(100): + f = session.execute_async(query, (i,), trace=True) + full_dc1_replicas = [h for h in cluster.metadata.get_replicas('test_tr', cqltypes.Int32Type.serialize(i, cluster.protocol_version)) + if h.datacenter == 'dc1'] + self.assertEqual(len(full_dc1_replicas), 2) + + f.result() + trace_hosts = [cluster.metadata.get_host(e.source) for e in f.get_query_trace().events] + + for h in f.attempted_hosts: + self.assertIn(h, full_dc1_replicas) + for h in trace_hosts: + self.assertIn(h, full_dc1_replicas) + + + def _set_up_shuffle_test(self, keyspace, replication_factor): + use_singledc() + cluster, session = self._cluster_session_with_lbp( + TokenAwarePolicy(RoundRobinPolicy(), shuffle_replicas=True) + ) + self._wait_for_nodes_up(range(1, 4), cluster) + + create_schema(cluster, session, keyspace, replication_factor=replication_factor) + return cluster, session + + def _check_query_order_changes(self, session, keyspace): + LIMIT_TRIES, tried, query_counts = 20, 0, set() + + while len(query_counts) <= 1: + tried += 1 + if tried >= LIMIT_TRIES: + raise Exception("After {0} tries shuffle returned the same output".format(LIMIT_TRIES)) + + self._insert(session, keyspace) + self._query(session, keyspace) + + loop_qcs = (self.coordinator_stats.get_query_count(1), + self.coordinator_stats.get_query_count(2), + self.coordinator_stats.get_query_count(3)) + + query_counts.add(loop_qcs) + self.assertEqual(sum(loop_qcs), 12) + + # end the loop if we get more than one query ordering + self.coordinator_stats.reset_counts() def test_white_list(self): use_singledc() keyspace = 'test_white_list' - cluster = Cluster(('127.0.0.2',), - load_balancing_policy=WhiteListRoundRobinPolicy((IP_FORMAT % 2,)), - protocol_version=PROTOCOL_VERSION) + cluster = TestCluster( + contact_points=('127.0.0.2',), topology_event_refresh_window=0, status_event_refresh_window=0, + execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile( + load_balancing_policy=WhiteListRoundRobinPolicy((IP_FORMAT % 2,)) + ) + } + ) + self.addCleanup(cluster.shutdown) session = cluster.connect() - wait_for_up(cluster, 1, wait=False) - wait_for_up(cluster, 2, wait=False) - wait_for_up(cluster, 3) + self._wait_for_nodes_up([1, 2, 3]) create_schema(cluster, session, keyspace) self._insert(session, keyspace) @@ -564,12 +715,12 @@ def test_white_list(self): # white list policy should not allow reconnecting to ignored hosts force_stop(3) - wait_for_down(cluster, 3) - self.assertFalse(cluster.metadata._hosts[IP_FORMAT % 3].is_currently_reconnecting()) + self._wait_for_nodes_down([3]) + self.assertFalse(cluster.metadata.get_host(IP_FORMAT % 3).is_currently_reconnecting()) self.coordinator_stats.reset_counts() force_stop(2) - time.sleep(10) + self._wait_for_nodes_down([2]) try: self._query(session, keyspace) @@ -577,4 +728,50 @@ def 
test_white_list(self): except NoHostAvailable: pass - cluster.shutdown() + def test_black_list_with_host_filter_policy(self): + """ + Test to validate removing certain hosts from the query plan with + HostFilterPolicy + @since 3.8 + @jira_ticket PYTHON-961 + @expected_result the excluded hosts are ignored + + @test_category policy + """ + use_singledc() + keyspace = 'test_black_list_with_hfp' + ignored_address = (IP_FORMAT % 2) + hfp = HostFilterPolicy( + child_policy=RoundRobinPolicy(), + predicate=lambda host: host.address != ignored_address + ) + cluster = TestCluster( + contact_points=(IP_FORMAT % 1,), + topology_event_refresh_window=0, + status_event_refresh_window=0, + execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(load_balancing_policy=hfp)} + ) + self.addCleanup(cluster.shutdown) + session = cluster.connect() + self._wait_for_nodes_up([1, 2, 3]) + + self.assertNotIn(ignored_address, [h.address for h in hfp.make_query_plan()]) + + create_schema(cluster, session, keyspace) + self._insert(session, keyspace) + self._query(session, keyspace) + + # RoundRobin doesn't provide a gurantee on the order of the hosts + # so we will have that for 127.0.0.1 and 127.0.0.3 the count for one + # will be 4 and for the other 8 + first_node_count = self.coordinator_stats.get_query_count(1) + third_node_count = self.coordinator_stats.get_query_count(3) + self.assertEqual(first_node_count + third_node_count, 12) + self.assertTrue(first_node_count == 8 or first_node_count == 4) + + self.coordinator_stats.assert_query_count_equals(self, 2, 0) + + # policy should not allow reconnecting to ignored host + force_stop(2) + self._wait_for_nodes_down([2]) + self.assertFalse(cluster.metadata.get_host(ignored_address).is_currently_reconnecting()) diff --git a/tests/integration/long/test_policies.py b/tests/integration/long/test_policies.py new file mode 100644 index 0000000000..751c6131ec --- /dev/null +++ b/tests/integration/long/test_policies.py @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from cassandra import ConsistencyLevel, Unavailable +from cassandra.cluster import ExecutionProfile, EXEC_PROFILE_DEFAULT + +from tests.integration import use_cluster, get_cluster, get_node, TestCluster + + +def setup_module(): + use_cluster('test_cluster', [4]) + + +class RetryPolicyTests(unittest.TestCase): + + @classmethod + def tearDownClass(cls): + cluster = get_cluster() + cluster.start(wait_for_binary_proto=True) # make sure other nodes are restarted + + def test_should_rethrow_on_unvailable_with_default_policy_if_cas(self): + """ + Tests for the default retry policy in combination with lightweight transactions. + + @since 3.17 + @jira_ticket PYTHON-1007 + @expected_result the query is retried with the default CL, not the serial one. 
+ + @test_category policy + """ + ep = ExecutionProfile(consistency_level=ConsistencyLevel.ALL, + serial_consistency_level=ConsistencyLevel.SERIAL) + + cluster = TestCluster(execution_profiles={EXEC_PROFILE_DEFAULT: ep}) + session = cluster.connect() + + session.execute("CREATE KEYSPACE test_retry_policy_cas WITH replication = {'class':'SimpleStrategy','replication_factor': 3};") + session.execute("CREATE TABLE test_retry_policy_cas.t (id int PRIMARY KEY, data text);") + session.execute('INSERT INTO test_retry_policy_cas.t ("id", "data") VALUES (%(0)s, %(1)s)', {'0': 42, '1': 'testing'}) + + get_node(2).stop() + get_node(4).stop() + + # before fix: cassandra.InvalidRequest: Error from server: code=2200 [Invalid query] message="SERIAL is not + # supported as conditional update commit consistency. ...."" + + # after fix: cassandra.Unavailable (expected since replicas are down) + with self.assertRaises(Unavailable) as cm: + session.execute("update test_retry_policy_cas.t set data = 'staging' where id = 42 if data ='testing'") + + exception = cm.exception + self.assertEqual(exception.consistency, ConsistencyLevel.SERIAL) + self.assertEqual(exception.required_replicas, 2) + self.assertEqual(exception.alive_replicas, 1) diff --git a/tests/integration/long/test_schema.py b/tests/integration/long/test_schema.py index 87ca5f3c8a..4e6784a967 100644 --- a/tests/integration/long/test_schema.py +++ b/tests/integration/long/test_schema.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -15,15 +17,13 @@ import logging from cassandra import ConsistencyLevel, AlreadyExists -from cassandra.cluster import Cluster from cassandra.query import SimpleStatement -from tests.integration import use_singledc, PROTOCOL_VERSION, execute_until_pass +from tests.integration import use_singledc, execute_until_pass, TestCluster + +import time -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest log = logging.getLogger(__name__) @@ -36,8 +36,8 @@ class SchemaTests(unittest.TestCase): @classmethod def setup_class(cls): - cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION) - cls.session = cls.cluster.connect() + cls.cluster = TestCluster() + cls.session = cls.cluster.connect(wait_for_all_pools=True) @classmethod def teardown_class(cls): @@ -97,8 +97,8 @@ def test_for_schema_disagreements_same_keyspace(self): Tests for any schema disagreements using the same keyspace multiple times """ - cluster = Cluster(protocol_version=PROTOCOL_VERSION) - session = cluster.connect() + cluster = TestCluster() + session = cluster.connect(wait_for_all_pools=True) for i in range(30): try: @@ -113,6 +113,7 @@ def test_for_schema_disagreements_same_keyspace(self): execute_until_pass(session, "INSERT INTO test.cf (key, value) VALUES ({0}, {0})".format(j)) execute_until_pass(session, "DROP KEYSPACE test") + cluster.shutdown() def test_for_schema_disagreement_attribute(self): """ @@ -130,27 +131,33 @@ def test_for_schema_disagreement_attribute(self): @test_category schema """ # This should yield a schema disagreement - cluster = Cluster(protocol_version=PROTOCOL_VERSION, max_schema_agreement_wait=0.001) - session = cluster.connect() + cluster = TestCluster(max_schema_agreement_wait=0.001) + session = cluster.connect(wait_for_all_pools=True) - rs = session.execute("CREATE KEYSPACE test_schema_disagreement WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}") + rs = session.execute("CREATE KEYSPACE test_schema_disagreement WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 3}") self.check_and_wait_for_agreement(session, rs, False) - rs = session.execute("CREATE TABLE test_schema_disagreement.cf (key int PRIMARY KEY, value int)") + rs = session.execute(SimpleStatement("CREATE TABLE test_schema_disagreement.cf (key int PRIMARY KEY, value int)", + consistency_level=ConsistencyLevel.ALL)) self.check_and_wait_for_agreement(session, rs, False) rs = session.execute("DROP KEYSPACE test_schema_disagreement") self.check_and_wait_for_agreement(session, rs, False) - + cluster.shutdown() + # These should have schema agreement - cluster = Cluster(protocol_version=PROTOCOL_VERSION, max_schema_agreement_wait=100) + cluster = TestCluster(max_schema_agreement_wait=100) session = cluster.connect() - rs = session.execute("CREATE KEYSPACE test_schema_disagreement WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}") + rs = session.execute("CREATE KEYSPACE test_schema_disagreement WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 3}") self.check_and_wait_for_agreement(session, rs, True) - rs = session.execute("CREATE TABLE test_schema_disagreement.cf (key int PRIMARY KEY, value int)") + rs = session.execute(SimpleStatement("CREATE TABLE 
test_schema_disagreement.cf (key int PRIMARY KEY, value int)", + consistency_level=ConsistencyLevel.ALL)) self.check_and_wait_for_agreement(session, rs, True) rs = session.execute("DROP KEYSPACE test_schema_disagreement") self.check_and_wait_for_agreement(session, rs, True) + cluster.shutdown() def check_and_wait_for_agreement(self, session, rs, exepected): + # Wait for RESULT_KIND_SCHEMA_CHANGE message to arrive + time.sleep(1) self.assertEqual(rs.response_future.is_schema_agreed, exepected) if not rs.response_future.is_schema_agreed: session.cluster.control_connection.wait_for_schema_agreement(wait_time=1000) diff --git a/tests/integration/long/test_ssl.py b/tests/integration/long/test_ssl.py index fc6a7066cd..5d86063d3e 100644 --- a/tests/integration/long/test_ssl.py +++ b/tests/integration/long/test_ssl.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,70 +14,121 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-try: - import unittest2 as unittest -except ImportError: - import unittest +import unittest -import os, sys, traceback, logging, ssl -from cassandra.cluster import Cluster, NoHostAvailable +import os, sys, traceback, logging, ssl, time, math, uuid +from cassandra.cluster import NoHostAvailable +from cassandra.connection import DefaultEndPoint from cassandra import ConsistencyLevel from cassandra.query import SimpleStatement -from tests.integration import use_singledc, PROTOCOL_VERSION, get_cluster, remove_cluster + +from OpenSSL import SSL, crypto + +from tests.integration import ( + get_cluster, remove_cluster, use_single_node, start_cluster_wait_for_up, EVENT_LOOP_MANAGER, TestCluster +) + +if not hasattr(ssl, 'match_hostname'): + try: + from ssl import match_hostname + ssl.match_hostname = match_hostname + except ImportError: + pass # tests will fail log = logging.getLogger(__name__) DEFAULT_PASSWORD = "cassandra" # Server keystore trust store locations -SERVER_KEYSTORE_PATH = "tests/integration/long/ssl/server_keystore.jks" -SERVER_TRUSTSTORE_PATH = "tests/integration/long/ssl/server_trust.jks" +SERVER_KEYSTORE_PATH = os.path.abspath("tests/integration/long/ssl/127.0.0.1.keystore") +SERVER_TRUSTSTORE_PATH = os.path.abspath("tests/integration/long/ssl/cassandra.truststore") # Client specific keys/certs -CLIENT_CA_CERTS = 'tests/integration/long/ssl/driver_ca_cert.pem' -DRIVER_KEYFILE = "tests/integration/long/ssl/python_driver_no_pass.key" -DRIVER_CERTFILE = "tests/integration/long/ssl/python_driver.pem" -DRIVER_CERTFILE_BAD = "tests/integration/long/ssl/python_driver_bad.pem" +CLIENT_CA_CERTS = os.path.abspath("tests/integration/long/ssl/rootCa.crt") +DRIVER_KEYFILE = os.path.abspath("tests/integration/long/ssl/client.key") +DRIVER_KEYFILE_ENCRYPTED = os.path.abspath("tests/integration/long/ssl/client_encrypted.key") +DRIVER_CERTFILE = os.path.abspath("tests/integration/long/ssl/client.crt_signed") +DRIVER_CERTFILE_BAD = os.path.abspath("tests/integration/long/ssl/client_bad.key") + +USES_PYOPENSSL = "twisted" in EVENT_LOOP_MANAGER or "eventlet" in EVENT_LOOP_MANAGER +if "twisted" in EVENT_LOOP_MANAGER: + import OpenSSL + ssl_version = OpenSSL.SSL.TLSv1_2_METHOD + verify_certs = {'cert_reqs': SSL.VERIFY_PEER, + 'check_hostname': True} +else: + ssl_version = ssl.PROTOCOL_TLS + verify_certs = {'cert_reqs': ssl.CERT_REQUIRED, + 'check_hostname': True} + + +def verify_callback(connection, x509, errnum, errdepth, ok): + return ok def setup_cluster_ssl(client_auth=False): """ We need some custom setup for this module. This will start the ccm cluster with basic - ssl connectivity, and client authenticiation if needed. + ssl connectivity, and client authentication if needed. """ - use_singledc(start=False) + use_single_node(start=False) ccm_cluster = get_cluster() ccm_cluster.stop() - # Fetch the absolute path to the keystore for ccm. - abs_path_server_keystore_path = os.path.abspath(SERVER_KEYSTORE_PATH) - # Configure ccm to use ssl. 
- config_options = {'client_encryption_options': {'enabled': True, - 'keystore': abs_path_server_keystore_path, + 'keystore': SERVER_KEYSTORE_PATH, 'keystore_password': DEFAULT_PASSWORD}} if(client_auth): - abs_path_server_truststore_path = os.path.abspath(SERVER_TRUSTSTORE_PATH) client_encyrption_options = config_options['client_encryption_options'] client_encyrption_options['require_client_auth'] = True - client_encyrption_options['truststore'] = abs_path_server_truststore_path + client_encyrption_options['truststore'] = SERVER_TRUSTSTORE_PATH client_encyrption_options['truststore_password'] = DEFAULT_PASSWORD ccm_cluster.set_configuration_options(config_options) - ccm_cluster.start(wait_for_binary_proto=True, wait_other_notice=True) + start_cluster_wait_for_up(ccm_cluster) -def teardown_module(): - """ - The rest of the tests don't need ssl enabled, remove the cluster so as to not interfere with other tests. - """ +def validate_ssl_options(**kwargs): + ssl_options = kwargs.get('ssl_options', None) + ssl_context = kwargs.get('ssl_context', None) + hostname = kwargs.get('hostname', '127.0.0.1') - ccm_cluster = get_cluster() - ccm_cluster.stop() - remove_cluster() + # find absolute path to client CA_CERTS + tries = 0 + while True: + if tries > 5: + raise RuntimeError("Failed to connect to SSL cluster after 5 attempts") + try: + cluster = TestCluster( + contact_points=[DefaultEndPoint(hostname)], + ssl_options=ssl_options, + ssl_context=ssl_context + ) + session = cluster.connect(wait_for_all_pools=True) + break + except Exception: + ex_type, ex, tb = sys.exc_info() + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + del tb + tries += 1 + + # attempt a few simple commands. + insert_keyspace = """CREATE KEYSPACE ssltest + WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'} + """ + statement = SimpleStatement(insert_keyspace) + statement.consistency_level = 3 + session.execute(statement) + + drop_keyspace = "DROP KEYSPACE ssltest" + statement = SimpleStatement(drop_keyspace) + statement.consistency_level = ConsistencyLevel.ANY + session.execute(statement) + + cluster.shutdown() class SSLConnectionTests(unittest.TestCase): @@ -84,6 +137,12 @@ class SSLConnectionTests(unittest.TestCase): def setUpClass(cls): setup_cluster_ssl() + @classmethod + def tearDownClass(cls): + ccm_cluster = get_cluster() + ccm_cluster.stop() + remove_cluster() + def test_can_connect_with_ssl_ca(self): """ Test to validate that we are able to connect to a cluster using ssl. 
@@ -101,38 +160,66 @@ def test_can_connect_with_ssl_ca(self): """ # find absolute path to client CA_CERTS - abs_path_ca_cert_path = os.path.abspath(CLIENT_CA_CERTS) + ssl_options = {'ca_certs': CLIENT_CA_CERTS,'ssl_version': ssl_version} + validate_ssl_options(ssl_options=ssl_options) + + def test_can_connect_with_ssl_long_running(self): + """ + Test to validate that long running ssl connections continue to function past their timeout window + + @since 3.6.0 + @jira_ticket PYTHON-600 + @expected_result The client can connect via SSL and perform some basic operations over a period of longer than a minute + + @test_category connection:ssl + """ + # find absolute path to client CA_CERTS + abs_path_ca_cert_path = os.path.abspath(CLIENT_CA_CERTS) + ssl_options = {'ca_certs': abs_path_ca_cert_path, + 'ssl_version': ssl_version} tries = 0 while True: if tries > 5: raise RuntimeError("Failed to connect to SSL cluster after 5 attempts") try: - cluster = Cluster(protocol_version=PROTOCOL_VERSION, ssl_options={'ca_certs': abs_path_ca_cert_path, - 'ssl_version': ssl.PROTOCOL_TLSv1}) - session = cluster.connect() + cluster = TestCluster(ssl_options=ssl_options) + session = cluster.connect(wait_for_all_pools=True) break except Exception: ex_type, ex, tb = sys.exc_info() - log.warn("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) del tb tries += 1 # attempt a few simple commands. - insert_keyspace = """CREATE KEYSPACE ssltest - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'} - """ - statement = SimpleStatement(insert_keyspace) - statement.consistency_level = 3 - session.execute(statement) - drop_keyspace = "DROP KEYSPACE ssltest" - statement = SimpleStatement(drop_keyspace) - statement.consistency_level = ConsistencyLevel.ANY - session.execute(statement) + for i in range(8): + rs = session.execute("SELECT * FROM system.local") + time.sleep(10) cluster.shutdown() + def test_can_connect_with_ssl_ca_host_match(self): + """ + Test to validate that we are able to connect to a cluster using ssl, and host matching + + test_can_connect_with_ssl_ca_host_match performs a simple sanity check to ensure that we can connect to a cluster with ssl + authentication via simple server-side shared certificate authority. It also validates that the host ip matches what is expected + + @since 3.3 + @jira_ticket PYTHON-296 + @expected_result The client can connect via SSL and perform some basic operations, with check_hostname specified + + @test_category connection:ssl + """ + + ssl_options = {'ca_certs': CLIENT_CA_CERTS, + 'ssl_version': ssl_version} + ssl_options.update(verify_certs) + + validate_ssl_options(ssl_options=ssl_options) + class SSLConnectionAuthTests(unittest.TestCase): @@ -140,6 +227,12 @@ class SSLConnectionAuthTests(unittest.TestCase): def setUpClass(cls): setup_cluster_ssl(client_auth=True) + @classmethod + def tearDownClass(cls): + ccm_cluster = get_cluster() + ccm_cluster.stop() + remove_cluster() + def test_can_connect_with_ssl_client_auth(self): """ Test to validate that we can connect to a C* cluster that has client_auth enabled.
@@ -154,44 +247,34 @@ def test_can_connect_with_ssl_client_auth(self): @test_category connection:ssl """ - # Need to get absolute paths for certs/key - abs_path_ca_cert_path = os.path.abspath(CLIENT_CA_CERTS) - abs_driver_keyfile = os.path.abspath(DRIVER_KEYFILE) - abs_driver_certfile = os.path.abspath(DRIVER_CERTFILE) + ssl_options = {'ca_certs': CLIENT_CA_CERTS, + 'ssl_version': ssl_version, + 'keyfile': DRIVER_KEYFILE, + 'certfile': DRIVER_CERTFILE} + validate_ssl_options(ssl_options=ssl_options) - tries = 0 - while True: - if tries > 5: - raise RuntimeError("Failed to connect to SSL cluster after 5 attempts") - try: - cluster = Cluster(protocol_version=PROTOCOL_VERSION, ssl_options={'ca_certs': abs_path_ca_cert_path, - 'ssl_version': ssl.PROTOCOL_TLSv1, - 'keyfile': abs_driver_keyfile, - 'certfile': abs_driver_certfile}) + def test_can_connect_with_ssl_client_auth_host_name(self): + """ + Test to validate that we can connect to a C* cluster that has client_auth enabled, and host matching - session = cluster.connect() - break - except Exception: - ex_type, ex, tb = sys.exc_info() - log.warn("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) - del tb - tries += 1 + This test will set up and use a C* cluster that has client authentication enabled. It will then attempt + to connect using valid client keys and certs (that are in the server's truststore), and attempt to perform some + basic operations, with check_hostname specified + @jira_ticket PYTHON-296 + @since 3.3 - # attempt a few simple commands. + @expected_result The client can connect via SSL and perform some basic operations - insert_keyspace = """CREATE KEYSPACE ssltest - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'} - """ - statement = SimpleStatement(insert_keyspace) - statement.consistency_level = 3 - session.execute(statement) + @test_category connection:ssl + """ - drop_keyspace = "DROP KEYSPACE ssltest" - statement = SimpleStatement(drop_keyspace) - statement.consistency_level = ConsistencyLevel.ANY - session.execute(statement) + ssl_options = {'ca_certs': CLIENT_CA_CERTS, + 'ssl_version': ssl_version, + 'keyfile': DRIVER_KEYFILE, + 'certfile': DRIVER_CERTFILE} + ssl_options.update(verify_certs) - cluster.shutdown() + validate_ssl_options(ssl_options=ssl_options) def test_cannot_connect_without_client_auth(self): """ @@ -206,17 +289,16 @@ def test_cannot_connect_without_client_auth(self): @test_category connection:ssl """ - abs_path_ca_cert_path = os.path.abspath(CLIENT_CA_CERTS) - cluster = Cluster(protocol_version=PROTOCOL_VERSION, ssl_options={'ca_certs': abs_path_ca_cert_path, - 'ssl_version': ssl.PROTOCOL_TLSv1}) - # attempt to connect and expect an exception + cluster = TestCluster(ssl_options={'ca_certs': CLIENT_CA_CERTS, + 'ssl_version': ssl_version}) - with self.assertRaises(NoHostAvailable) as context: + with self.assertRaises(NoHostAvailable) as _: cluster.connect() + cluster.shutdown() def test_cannot_connect_with_bad_client_auth(self): """ - Test to validate that we cannot connect with invalid client auth. + Test to validate that we cannot connect with invalid client auth. This test will use bad keys/certs to perform client authentication. It will then attempt to connect to a server that has client authentication enabled.
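# Illustrative sketch (not part of the patch): the failure pattern shared by the
# negative client-auth tests -- when the server requires a client certificate that the
# driver does not present (or presents a bad one), the handshake fails and connect()
# surfaces it as NoHostAvailable.
from cassandra.cluster import Cluster, NoHostAvailable

def connection_is_rejected(ssl_options):
    cluster = Cluster(ssl_options=ssl_options)
    try:
        cluster.connect()
        return False
    except NoHostAvailable:
        return True
    finally:
        cluster.shutdown()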
@@ -228,14 +310,194 @@ def test_cannot_connect_with_bad_client_auth(self): @test_category connection:ssl """ - # Setup absolute paths to key/cert files - abs_path_ca_cert_path = os.path.abspath(CLIENT_CA_CERTS) - abs_driver_keyfile = os.path.abspath(DRIVER_KEYFILE) - abs_driver_certfile = os.path.abspath(DRIVER_CERTFILE_BAD) - - cluster = Cluster(protocol_version=PROTOCOL_VERSION, ssl_options={'ca_certs': abs_path_ca_cert_path, - 'ssl_version': ssl.PROTOCOL_TLSv1, - 'keyfile': abs_driver_keyfile, - 'certfile': abs_driver_certfile}) - with self.assertRaises(NoHostAvailable) as context: + ssl_options = {'ca_certs': CLIENT_CA_CERTS, + 'ssl_version': ssl_version, + 'keyfile': DRIVER_KEYFILE} + + if not USES_PYOPENSSL: + # I don't set the bad certfile for pyopenssl because it hangs + ssl_options['certfile'] = DRIVER_CERTFILE_BAD + + cluster = TestCluster(ssl_options=ssl_options) + + with self.assertRaises(NoHostAvailable) as _: cluster.connect() + cluster.shutdown() + + def test_cannot_connect_with_invalid_hostname(self): + ssl_options = {'ca_certs': CLIENT_CA_CERTS, + 'ssl_version': ssl_version, + 'keyfile': DRIVER_KEYFILE, + 'certfile': DRIVER_CERTFILE} + ssl_options.update(verify_certs) + + with self.assertRaises(Exception): + validate_ssl_options(ssl_options=ssl_options, hostname='localhost') + + +class SSLSocketErrorTests(unittest.TestCase): + + @classmethod + def setUpClass(cls): + setup_cluster_ssl() + + @classmethod + def tearDownClass(cls): + ccm_cluster = get_cluster() + ccm_cluster.stop() + remove_cluster() + + def test_ssl_want_write_errors_are_retried(self): + """ + Test that when a socket receives a WANT_WRITE error, the message chunk sending is retried. + + @since 3.17.0 + @jira_ticket PYTHON-891 + @expected_result The query is executed successfully + + @test_category connection:ssl + """ + ssl_options = {'ca_certs': CLIENT_CA_CERTS, + 'ssl_version': ssl_version} + cluster = TestCluster(ssl_options=ssl_options) + session = cluster.connect(wait_for_all_pools=True) + try: + session.execute('drop keyspace ssl_error_test') + except: + pass + session.execute( + "CREATE KEYSPACE ssl_error_test WITH replication = {'class':'SimpleStrategy','replication_factor':1};") + session.execute("CREATE TABLE ssl_error_test.big_text (id uuid PRIMARY KEY, data text);") + + params = { + '0': uuid.uuid4(), + '1': "0" * int(math.pow(10, 7)) + } + + session.execute('INSERT INTO ssl_error_test.big_text ("id", "data") VALUES (%(0)s, %(1)s)', params) + + +class SSLConnectionWithSSLContextTests(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_cluster_ssl() + + @classmethod + def tearDownClass(cls): + ccm_cluster = get_cluster() + ccm_cluster.stop() + remove_cluster() + + def test_can_connect_with_sslcontext_certificate(self): + """ + Test to validate that we are able to connect to a cluster using a SSLContext.
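# Illustrative sketch (not part of the patch, stdlib ssl backend only): an SSLContext
# roughly equivalent to the ssl_options dictionaries used above -- trust the test CA
# and require a valid server certificate.  PROTOCOL_TLS_CLIENT is an assumption; the
# tests themselves build the context from the module-level ssl_version, and with
# pyOpenSSL they build an SSL.Context instead.
import ssl

def make_verified_context(ca_certs_path):
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    ctx.check_hostname = False            # host matching is exercised separately
    ctx.verify_mode = ssl.CERT_REQUIRED
    ctx.load_verify_locations(ca_certs_path)
    return ctx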
+ + @since 3.17.0 + @jira_ticket PYTHON-995 + @expected_result The client can connect via SSL and preform some basic operations + + @test_category connection:ssl + """ + if USES_PYOPENSSL: + ssl_context = SSL.Context(SSL.TLSv1_2_METHOD) + ssl_context.load_verify_locations(CLIENT_CA_CERTS) + else: + ssl_context = ssl.SSLContext(ssl_version) + ssl_context.load_verify_locations(CLIENT_CA_CERTS) + ssl_context.verify_mode = ssl.CERT_REQUIRED + validate_ssl_options(ssl_context=ssl_context) + + def test_can_connect_with_ssl_client_auth_password_private_key(self): + """ + Identical test to SSLConnectionAuthTests.test_can_connect_with_ssl_client_auth, + the only difference is that the DRIVER_KEYFILE is encrypted with a password. + + @since 3.17.0 + @jira_ticket PYTHON-995 + @expected_result The client can connect via SSL and preform some basic operations + + @test_category connection:ssl + """ + abs_driver_keyfile = os.path.abspath(DRIVER_KEYFILE_ENCRYPTED) + abs_driver_certfile = os.path.abspath(DRIVER_CERTFILE) + ssl_options = {} + + if USES_PYOPENSSL: + ssl_context = SSL.Context(SSL.TLSv1_2_METHOD) + ssl_context.use_certificate_file(abs_driver_certfile) + with open(abs_driver_keyfile) as keyfile: + key = crypto.load_privatekey(crypto.FILETYPE_PEM, keyfile.read(), b'cassandra') + ssl_context.use_privatekey(key) + ssl_context.set_verify(SSL.VERIFY_NONE, verify_callback) + else: + ssl_context = ssl.SSLContext(ssl_version) + ssl_context.load_cert_chain(certfile=abs_driver_certfile, + keyfile=abs_driver_keyfile, + password="cassandra") + ssl_context.verify_mode = ssl.CERT_NONE + validate_ssl_options(ssl_context=ssl_context, ssl_options=ssl_options) + + def test_can_connect_with_ssl_context_ca_host_match(self): + """ + Test to validate that we are able to connect to a cluster using a SSLContext + using client auth, an encrypted keyfile, and host matching + """ + ssl_options = {} + if USES_PYOPENSSL: + ssl_context = SSL.Context(SSL.TLSv1_2_METHOD) + ssl_context.use_certificate_file(DRIVER_CERTFILE) + with open(DRIVER_KEYFILE_ENCRYPTED) as keyfile: + key = crypto.load_privatekey(crypto.FILETYPE_PEM, keyfile.read(), b'cassandra') + ssl_context.use_privatekey(key) + ssl_context.load_verify_locations(CLIENT_CA_CERTS) + ssl_options["check_hostname"] = True + else: + ssl_context = ssl.SSLContext(ssl_version) + ssl_context.verify_mode = ssl.CERT_REQUIRED + ssl_context.load_verify_locations(CLIENT_CA_CERTS) + ssl_context.load_cert_chain( + certfile=DRIVER_CERTFILE, + keyfile=DRIVER_KEYFILE_ENCRYPTED, + password="cassandra", + ) + ssl_context.verify_mode = ssl.CERT_REQUIRED + ssl_options["check_hostname"] = True + validate_ssl_options(ssl_context=ssl_context, ssl_options=ssl_options) + + def test_cannot_connect_ssl_context_with_invalid_hostname(self): + ssl_options = {} + if USES_PYOPENSSL: + ssl_context = SSL.Context(SSL.TLSv1_2_METHOD) + ssl_context.use_certificate_file(DRIVER_CERTFILE) + with open(DRIVER_KEYFILE_ENCRYPTED) as keyfile: + key = crypto.load_privatekey(crypto.FILETYPE_PEM, keyfile.read(), b"cassandra") + ssl_context.use_privatekey(key) + ssl_context.load_verify_locations(CLIENT_CA_CERTS) + ssl_options["check_hostname"] = True + else: + ssl_context = ssl.SSLContext(ssl_version) + ssl_context.verify_mode = ssl.CERT_REQUIRED + ssl_context.load_verify_locations(CLIENT_CA_CERTS) + ssl_context.load_cert_chain( + certfile=DRIVER_CERTFILE, + keyfile=DRIVER_KEYFILE_ENCRYPTED, + password="cassandra", + ) + ssl_context.verify_mode = ssl.CERT_REQUIRED + ssl_options["check_hostname"] = True + with 
self.assertRaises(Exception): + validate_ssl_options(ssl_context=ssl_context, ssl_options=ssl_options, hostname="localhost") + + @unittest.skipIf(USES_PYOPENSSL, "This test is for the built-in ssl.Context") + def test_can_connect_with_sslcontext_default_context(self): + """ + Test to validate that we are able to connect to a cluster using a SSLContext created from create_default_context(). + @expected_result The client can connect via SSL and preform some basic operations + @test_category connection:ssl + """ + ssl_context = ssl.create_default_context(cafile=CLIENT_CA_CERTS) + validate_ssl_options(ssl_context=ssl_context) diff --git a/tests/integration/long/test_topology_change.py b/tests/integration/long/test_topology_change.py new file mode 100644 index 0000000000..5b12eef28c --- /dev/null +++ b/tests/integration/long/test_topology_change.py @@ -0,0 +1,48 @@ +from unittest import TestCase + +from cassandra.policies import HostStateListener +from tests.integration import get_node, use_cluster, local, TestCluster +from tests.integration.long.utils import decommission +from tests.util import wait_until + + +class StateListener(HostStateListener): + def __init__(self): + self.downed_host = None + self.removed_host = None + + def on_remove(self, host): + self.removed_host = host + + def on_up(self, host): + pass + + def on_down(self, host): + self.downed_host = host + + def on_add(self, host): + pass + + +class TopologyChangeTests(TestCase): + @local + def test_removed_node_stops_reconnecting(self): + """ Ensure we stop reconnecting after a node is removed. PYTHON-1181 """ + use_cluster("test_down_then_removed", [3], start=True) + + state_listener = StateListener() + cluster = TestCluster() + self.addCleanup(cluster.shutdown) + cluster.register_listener(state_listener) + session = cluster.connect(wait_for_all_pools=True) + + get_node(3).nodetool("disablebinary") + + wait_until(condition=lambda: state_listener.downed_host is not None, delay=2, max_attempts=50) + self.assertTrue(state_listener.downed_host.is_currently_reconnecting()) + + decommission(3) + + wait_until(condition=lambda: state_listener.removed_host is not None, delay=2, max_attempts=50) + self.assertIs(state_listener.downed_host, state_listener.removed_host) # Just a sanity check + self.assertFalse(state_listener.removed_host.is_currently_reconnecting()) diff --git a/tests/integration/long/utils.py b/tests/integration/long/utils.py index f850d6a1a8..cdbb177ec4 100644 --- a/tests/integration/long/utils.py +++ b/tests/integration/long/utils.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
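# Illustrative sketch (assumption, not part of the patch): wait_until(), imported from
# tests.util by the topology-change test above, is assumed to poll a predicate with a
# fixed delay and a bounded number of attempts, roughly like this.
import time

def wait_until(condition, delay, max_attempts):
    for _ in range(max_attempts):
        if condition():
            return
        time.sleep(delay)
    raise AssertionError("condition was not met after %d attempts" % max_attempts)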
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,16 +14,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import print_function import logging import time from collections import defaultdict -from ccmlib.node import Node +from packaging.version import Version -from cassandra.query import named_tuple_factory +from tests.integration import (get_node, get_cluster, wait_for_node_socket, + DSE_VERSION, CASSANDRA_VERSION) -from tests.integration import get_node, get_cluster IP_FORMAT = '127.0.0.%s' @@ -34,6 +35,7 @@ def __init__(self): self.coordinator_counts = defaultdict(int) def add_coordinator(self, future): + log.debug('adding coordinator from {}'.format(future)) future.result() coordinator = future._current_host.address self.coordinator_counts[coordinator] += 1 @@ -57,9 +59,6 @@ def assert_query_count_equals(self, testcase, node, expected): def create_schema(cluster, session, keyspace, simple_strategy=True, replication_factor=1, replication_strategy=None): - row_factory = session.row_factory - session.row_factory = named_tuple_factory - if keyspace in cluster.metadata.keyspaces.keys(): session.execute('DROP KEYSPACE %s' % keyspace, timeout=20) @@ -79,8 +78,6 @@ def create_schema(cluster, session, keyspace, simple_strategy=True, session.execute(ddl % keyspace, timeout=10) session.execute('USE %s' % keyspace) - session.row_factory = row_factory - def start(node): get_node(node).start() @@ -97,54 +94,73 @@ def force_stop(node): def decommission(node): - get_node(node).decommission() + if (DSE_VERSION and DSE_VERSION >= Version("5.1")) or CASSANDRA_VERSION >= Version("4.0-a"): + # CASSANDRA-12510 + get_node(node).decommission(force=True) + else: + get_node(node).decommission() get_node(node).stop() def bootstrap(node, data_center=None, token=None): - node_instance = Node('node%s' % node, - get_cluster(), - auto_bootstrap=False, - thrift_interface=(IP_FORMAT % node, 9160), - storage_interface=(IP_FORMAT % node, 7000), - jmx_port=str(7000 + 100 * node), - remote_debug_port=0, - initial_token=token if token else node * 10) - get_cluster().add(node_instance, is_seed=False, data_center=data_center) + log.debug('called bootstrap(' + 'node={node}, data_center={data_center}, ' + 'token={token})') + cluster = get_cluster() + # for now assumes cluster has at least one node + node_type = type(next(iter(cluster.nodes.values()))) + node_instance = node_type( + 'node%s' % node, + cluster, + auto_bootstrap=False, + thrift_interface=(IP_FORMAT % node, 9160), + storage_interface=(IP_FORMAT % node, 7000), + binary_interface=(IP_FORMAT % node, 9042), + jmx_port=str(7000 + 100 * node), + remote_debug_port=0, + initial_token=token if token else node * 10 + ) + cluster.add(node_instance, is_seed=False, data_center=data_center) try: - start(node) - except: + node_instance.start() + except Exception as e0: + log.debug('failed 1st bootstrap attempt with: \n{}'.format(e0)) # Try only twice try: - start(node) - except: + node_instance.start() + except Exception as e1: + log.debug('failed 2nd bootstrap attempt with: \n{}'.format(e1)) log.error('Added node failed to start twice.') + raise e1 def ring(node): - print('From node%s:' % node) get_node(node).nodetool('ring') -def wait_for_up(cluster, node, wait=True): 
+def wait_for_up(cluster, node): tries = 0 + addr = IP_FORMAT % node while tries < 100: - host = cluster.metadata.get_host(IP_FORMAT % node) + host = cluster.metadata.get_host(addr) if host and host.is_up: + wait_for_node_socket(get_node(node), 60) log.debug("Done waiting for node %s to be up", node) return else: - log.debug("Host is still marked down, waiting") + log.debug("Host {} is still marked down, waiting".format(addr)) tries += 1 time.sleep(1) - raise RuntimeError("Host {0} is not up after 100 attempts".format(IP_FORMAT.format(node))) + # todo: don't mix string interpolation methods in the same package + raise RuntimeError("Host {0} is not up after {1} attempts".format(addr, tries)) -def wait_for_down(cluster, node, wait=True): +def wait_for_down(cluster, node): log.debug("Waiting for node %s to be down", node) tries = 0 + addr = IP_FORMAT % node while tries < 100: host = cluster.metadata.get_host(IP_FORMAT % node) if not host or not host.is_up: @@ -155,4 +171,4 @@ def wait_for_down(cluster, node, wait=True): tries += 1 time.sleep(1) - raise RuntimeError("Host {0} is not down after 100 attempts".format(IP_FORMAT.format(node))) + raise RuntimeError("Host {0} is not down after {1} attempts".format(addr, tries)) diff --git a/tests/integration/simulacron/__init__.py b/tests/integration/simulacron/__init__.py new file mode 100644 index 0000000000..c959fd6e08 --- /dev/null +++ b/tests/integration/simulacron/__init__.py @@ -0,0 +1,82 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
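# Illustrative sketch (not part of the patch): the version gate behind
# decommission(force=True) above, shown in isolation.  Since CASSANDRA-12510,
# nodetool decommission refuses to shrink a cluster below its replication factor
# unless forced, which is why these tests pass force=True on newer C*/DSE versions.
from packaging.version import Version

def needs_forced_decommission(cassandra_version, dse_version=None):
    if dse_version is not None and dse_version >= Version("5.1"):
        return True
    return cassandra_version >= Version("4.0-a")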
+# See the License for the specific language governing permissions and +# limitations under the License +import unittest + +from tests.integration import requiredse, CASSANDRA_VERSION, DSE_VERSION, SIMULACRON_JAR, PROTOCOL_VERSION +from tests.integration.simulacron.utils import ( + clear_queries, + start_and_prime_singledc, + stop_simulacron, + start_and_prime_cluster_defaults, +) + +from cassandra.cluster import Cluster + +from packaging.version import Version + + +PROTOCOL_VERSION = min(4, PROTOCOL_VERSION if (DSE_VERSION is None or DSE_VERSION >= Version('5.0')) else 3) + + +def teardown_package(): + stop_simulacron() + + +class SimulacronBase(unittest.TestCase): + def tearDown(self): + clear_queries() + stop_simulacron() + + +class SimulacronCluster(SimulacronBase): + + cluster, connect = None, True + + @classmethod + def setUpClass(cls): + if SIMULACRON_JAR is None or CASSANDRA_VERSION < Version("2.1"): + return + + start_and_prime_singledc() + if cls.connect: + cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION, compression=False) + cls.session = cls.cluster.connect(wait_for_all_pools=True) + + @classmethod + def tearDownClass(cls): + if SIMULACRON_JAR is None or CASSANDRA_VERSION < Version("2.1"): + return + + if cls.cluster: + cls.cluster.shutdown() + stop_simulacron() + + +@requiredse +class DseSimulacronCluster(SimulacronBase): + + simulacron_cluster = None + cluster, connect = None, True + nodes_per_dc = 1 + + @classmethod + def setUpClass(cls): + if DSE_VERSION is None and SIMULACRON_JAR is None or CASSANDRA_VERSION < Version("2.1"): + return + + cls.simulacron_cluster = start_and_prime_cluster_defaults(dse_version=DSE_VERSION, + nodes_per_dc=cls.nodes_per_dc) + if cls.connect: + cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION, compression=False) + cls.session = cls.cluster.connect(wait_for_all_pools=True) diff --git a/tests/integration/simulacron/advanced/__init__.py b/tests/integration/simulacron/advanced/__init__.py new file mode 100644 index 0000000000..635f0d9e60 --- /dev/null +++ b/tests/integration/simulacron/advanced/__init__.py @@ -0,0 +1,15 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/integration/simulacron/advanced/test_insights.py b/tests/integration/simulacron/advanced/test_insights.py new file mode 100644 index 0000000000..07005a479b --- /dev/null +++ b/tests/integration/simulacron/advanced/test_insights.py @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import time +import json +import re + +from cassandra.cluster import Cluster +from cassandra.datastax.insights.util import version_supports_insights + +from tests.integration import requiressimulacron, requiredse, DSE_VERSION +from tests.integration.simulacron import DseSimulacronCluster, PROTOCOL_VERSION +from tests.integration.simulacron.utils import SimulacronClient, GetLogsQuery, ClearLogsQuery + + +@requiredse +@requiressimulacron +@unittest.skipUnless(DSE_VERSION and version_supports_insights(str(DSE_VERSION)), 'DSE {} does not support insights'.format(DSE_VERSION)) +class InsightsTests(DseSimulacronCluster): + """ + Tests insights integration + + @since 3.18 + @jira_ticket PYTHON-1047 + @expected_result startup and status messages are sent + """ + + connect = False + + def tearDown(self): + if self.cluster: + self.cluster.shutdown() + + @staticmethod + def _get_node_logs(raw_data): + return list(filter(lambda q: q['type'] == 'QUERY' and q['query'].startswith('CALL InsightsRpc.reportInsight'), + json.loads(raw_data)['data_centers'][0]['nodes'][0]['queries'])) + + @staticmethod + def _parse_data(data, index=0): + return json.loads(re.match( + r"CALL InsightsRpc.reportInsight\('(.+)'\)", + data[index]['frame']['message']['query']).group(1)) + + def test_startup_message(self): + self.cluster = Cluster(protocol_version=PROTOCOL_VERSION, compression=False) + self.session = self.cluster.connect(wait_for_all_pools=True) + + time.sleep(1) # wait the monitor thread is started + response = SimulacronClient().submit_request(GetLogsQuery()) + self.assertTrue('CALL InsightsRpc.reportInsight' in response) + + node_queries = self._get_node_logs(response) + self.assertEqual(1, len(node_queries)) + self.assertTrue(node_queries, "RPC query not found") + + message = self._parse_data(node_queries) + + self.assertEqual(message['metadata']['name'], 'driver.startup') + self.assertEqual(message['data']['initialControlConnection'], + self.cluster.control_connection._connection.host) + self.assertEqual(message['data']['sessionId'], str(self.session.session_id)) + self.assertEqual(message['data']['clientId'], str(self.cluster.client_id)) + self.assertEqual(message['data']['compression'], 'NONE') + + def test_status_message(self): + SimulacronClient().submit_request(ClearLogsQuery()) + + self.cluster = Cluster(protocol_version=PROTOCOL_VERSION, compression=False, monitor_reporting_interval=1) + self.session = self.cluster.connect(wait_for_all_pools=True) + + time.sleep(1.1) + response = SimulacronClient().submit_request(GetLogsQuery()) + self.assertTrue('CALL InsightsRpc.reportInsight' in response) + + node_queries = self._get_node_logs(response) + self.assertEqual(2, len(node_queries)) + self.assertTrue(node_queries, "RPC query not found") + + message = self._parse_data(node_queries, 1) + + self.assertEqual(message['metadata']['name'], 'driver.status') + self.assertEqual(message['data']['controlConnection'], + self.cluster.control_connection._connection.host) + self.assertEqual(message['data']['sessionId'], str(self.session.session_id)) + self.assertEqual(message['data']['clientId'], 
str(self.cluster.client_id)) + self.assertEqual(message['metadata']['insightType'], 'EVENT') + + def test_monitor_disabled(self): + SimulacronClient().submit_request(ClearLogsQuery()) + + self.cluster = Cluster(protocol_version=PROTOCOL_VERSION, compression=False, monitor_reporting_enabled=False) + self.session = self.cluster.connect(wait_for_all_pools=True) + + response = SimulacronClient().submit_request(GetLogsQuery()) + self.assertFalse('CALL InsightsRpc.reportInsight' in response) diff --git a/tests/integration/simulacron/conftest.py b/tests/integration/simulacron/conftest.py new file mode 100644 index 0000000000..a4377996bb --- /dev/null +++ b/tests/integration/simulacron/conftest.py @@ -0,0 +1,9 @@ +import pytest + +from tests.integration.simulacron import teardown_package + +@pytest.fixture(scope='session', autouse=True) +def setup_and_teardown_packages(): + print('setup') + yield + teardown_package() \ No newline at end of file diff --git a/tests/integration/simulacron/test_backpressure.py b/tests/integration/simulacron/test_backpressure.py new file mode 100644 index 0000000000..0418c05814 --- /dev/null +++ b/tests/integration/simulacron/test_backpressure.py @@ -0,0 +1,181 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
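# Illustrative sketch (not part of the patch): the pause/resume pattern the
# backpressure tests below build on.  PauseReads/ResumeReads are the simulacron admin
# requests from tests.integration.simulacron.utils; wrapping them in a context manager
# keeps a failing test from leaving reads paused.
from contextlib import contextmanager

from tests.integration.simulacron.utils import prime_request, PauseReads, ResumeReads

@contextmanager
def paused_reads():
    prime_request(PauseReads())
    try:
        yield
    finally:
        prime_request(ResumeReads())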
+import time + +from cassandra import OperationTimedOut +from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT, NoHostAvailable +from cassandra.policies import RoundRobinPolicy, WhiteListRoundRobinPolicy +from tests.integration import requiressimulacron, libevtest +from tests.integration.simulacron import SimulacronBase, PROTOCOL_VERSION +from tests.integration.simulacron.utils import ResumeReads, PauseReads, prime_request, start_and_prime_singledc + + +@requiressimulacron +@libevtest +class TCPBackpressureTests(SimulacronBase): + def setUp(self): + self.callback_successes = 0 + self.callback_errors = 0 + + def callback_success(self, results): + self.callback_successes += 1 + + def callback_error(self, results): + self.callback_errors += 1 + + def _fill_buffers(self, session, query, expected_blocked=3, **execute_kwargs): + futures = [] + buffer = '1' * 50000 + for _ in range(100000): + future = session.execute_async(query, [buffer], **execute_kwargs) + futures.append(future) + + total_blocked = 0 + for pool in session.get_pools(): + if not pool._connection._socket_writable: + total_blocked += 1 + if total_blocked >= expected_blocked: + break + else: + raise Exception("Unable to fill TCP send buffer on expected number of nodes") + return futures + + def test_paused_connections(self): + """ Verify all requests come back as expected if node resumes within query timeout """ + start_and_prime_singledc() + profile = ExecutionProfile(request_timeout=500, load_balancing_policy=RoundRobinPolicy()) + cluster = Cluster( + protocol_version=PROTOCOL_VERSION, + compression=False, + execution_profiles={EXEC_PROFILE_DEFAULT: profile}, + ) + session = cluster.connect(wait_for_all_pools=True) + self.addCleanup(cluster.shutdown) + + query = session.prepare("INSERT INTO table1 (id) VALUES (?)") + + prime_request(PauseReads()) + futures = self._fill_buffers(session, query) + + # Make sure we actually have some stuck in-flight requests + for in_flight in [pool._connection.in_flight for pool in session.get_pools()]: + self.assertGreater(in_flight, 100) + time.sleep(.5) + for in_flight in [pool._connection.in_flight for pool in session.get_pools()]: + self.assertGreater(in_flight, 100) + + prime_request(ResumeReads()) + + for future in futures: + try: + future.result() + except NoHostAvailable as e: + # We shouldn't have any timeouts here, but all of the queries beyond what can fit + # in the tcp buffer will have returned with a ConnectionBusy exception + self.assertIn("ConnectionBusy", str(e)) + + # Verify that we can continue sending queries without any problems + for host in session.cluster.metadata.all_hosts(): + session.execute(query, ["a"], host=host) + + def test_queued_requests_timeout(self): + """ Verify that queued requests timeout as expected """ + start_and_prime_singledc() + profile = ExecutionProfile(request_timeout=.1, load_balancing_policy=RoundRobinPolicy()) + cluster = Cluster( + protocol_version=PROTOCOL_VERSION, + compression=False, + execution_profiles={EXEC_PROFILE_DEFAULT: profile}, + ) + session = cluster.connect(wait_for_all_pools=True) + self.addCleanup(cluster.shutdown) + + query = session.prepare("INSERT INTO table1 (id) VALUES (?)") + + prime_request(PauseReads()) + + futures = [] + for i in range(1000): + future = session.execute_async(query, [str(i)]) + future.add_callbacks(callback=self.callback_success, errback=self.callback_error) + futures.append(future) + + successes = 0 + for future in futures: + try: + future.result() + successes += 1 + except 
OperationTimedOut: + pass + + # Simulacron will respond to a couple queries before cutting off reads, so we'll just verify + # that only "a few" successes happened here + self.assertLess(successes, 50) + self.assertLess(self.callback_successes, 50) + self.assertEqual(self.callback_errors, len(futures) - self.callback_successes) + + def test_cluster_busy(self): + """ Verify that once TCP buffer is full we get busy exceptions rather than timeouts """ + start_and_prime_singledc() + profile = ExecutionProfile(load_balancing_policy=RoundRobinPolicy()) + cluster = Cluster( + protocol_version=PROTOCOL_VERSION, + compression=False, + execution_profiles={EXEC_PROFILE_DEFAULT: profile}, + ) + session = cluster.connect(wait_for_all_pools=True) + self.addCleanup(cluster.shutdown) + + query = session.prepare("INSERT INTO table1 (id) VALUES (?)") + + prime_request(PauseReads()) + + # These requests will get stuck in the TCP buffer and we have no choice but to let them time out + self._fill_buffers(session, query, expected_blocked=3) + + # Now that our send buffer is completely full, verify we immediately get busy exceptions rather than timing out + for i in range(1000): + with self.assertRaises(NoHostAvailable) as e: + session.execute(query, [str(i)]) + self.assertIn("ConnectionBusy", str(e.exception)) + + def test_node_busy(self): + """ Verify that once TCP buffer is full, queries continue to get re-routed to other nodes """ + start_and_prime_singledc() + profile = ExecutionProfile(load_balancing_policy=RoundRobinPolicy()) + cluster = Cluster( + protocol_version=PROTOCOL_VERSION, + compression=False, + execution_profiles={EXEC_PROFILE_DEFAULT: profile}, + ) + session = cluster.connect(wait_for_all_pools=True) + self.addCleanup(cluster.shutdown) + + query = session.prepare("INSERT INTO table1 (id) VALUES (?)") + + prime_request(PauseReads(dc_id=0, node_id=0)) + + blocked_profile = ExecutionProfile(load_balancing_policy=WhiteListRoundRobinPolicy(["127.0.0.1"])) + cluster.add_execution_profile('blocked_profile', blocked_profile) + + # Fill our blocked node's tcp buffer until we get a busy exception + self._fill_buffers(session, query, expected_blocked=1, execution_profile='blocked_profile') + + # Now that our send buffer is completely full on one node, + # verify queries get re-routed to other nodes and queries complete successfully + for i in range(1000): + session.execute(query, [str(i)]) + diff --git a/tests/integration/simulacron/test_cluster.py b/tests/integration/simulacron/test_cluster.py new file mode 100644 index 0000000000..2dfbc1f786 --- /dev/null +++ b/tests/integration/simulacron/test_cluster.py @@ -0,0 +1,112 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
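# Illustrative sketch (not part of the patch): the general shape of priming an error
# response in simulacron, as the write-timeout test below does -- map a query string
# to a server-side error "then" clause.  The constants mirror the values used in
# test_writetimeout.
from tests.integration.simulacron.utils import prime_query

def prime_write_timeout(query, consistency="LOCAL_QUORUM"):
    then = {
        "result": "write_timeout",
        "delay_in_ms": 0,
        "consistency_level": consistency,
        "received": 1,
        "block_for": 4,
        "write_type": "UNLOGGED_BATCH",
        "ignore_on_prepare": True,
    }
    prime_query(query, then=then, rows=None, column_types=None)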
+import unittest + +import logging +from packaging.version import Version + +import cassandra +from tests.integration.simulacron import SimulacronCluster, SimulacronBase +from tests.integration import (requiressimulacron, PROTOCOL_VERSION, DSE_VERSION, MockLoggingHandler) +from tests.integration.simulacron.utils import prime_query, start_and_prime_singledc + +from cassandra import (WriteTimeout, WriteType, + ConsistencyLevel, UnresolvableContactPoints) +from cassandra.cluster import Cluster, ControlConnection + + +PROTOCOL_VERSION = min(4, PROTOCOL_VERSION if (DSE_VERSION is None or DSE_VERSION >= Version('5.0')) else 3) + +@requiressimulacron +class ClusterTests(SimulacronCluster): + def test_writetimeout(self): + write_type = "UNLOGGED_BATCH" + consistency = "LOCAL_QUORUM" + received_responses = 1 + required_responses = 4 + + query_to_prime_simple = "SELECT * from simulacron_keyspace.simple" + then = { + "result": "write_timeout", + "delay_in_ms": 0, + "consistency_level": consistency, + "received": received_responses, + "block_for": required_responses, + "write_type": write_type, + "ignore_on_prepare": True + } + prime_query(query_to_prime_simple, then=then, rows=None, column_types=None) + + with self.assertRaises(WriteTimeout) as assert_raised_context: + self.session.execute(query_to_prime_simple) + wt = assert_raised_context.exception + self.assertEqual(wt.write_type, WriteType.name_to_value[write_type]) + self.assertEqual(wt.consistency, ConsistencyLevel.name_to_value[consistency]) + self.assertEqual(wt.received_responses, received_responses) + self.assertEqual(wt.required_responses, required_responses) + self.assertIn(write_type, str(wt)) + self.assertIn(consistency, str(wt)) + self.assertIn(str(received_responses), str(wt)) + self.assertIn(str(required_responses), str(wt)) + + +@requiressimulacron +class ClusterDNSResolutionTests(SimulacronCluster): + + connect = False + + def tearDown(self): + if self.cluster: + self.cluster.shutdown() + + def test_connection_with_one_unresolvable_contact_point(self): + # shouldn't raise anything due to name resolution failures + self.cluster = Cluster(['127.0.0.1', 'dns.invalid'], + protocol_version=PROTOCOL_VERSION, + compression=False) + + def test_connection_with_only_unresolvable_contact_points(self): + with self.assertRaises(UnresolvableContactPoints): + self.cluster = Cluster(['dns.invalid'], + protocol_version=PROTOCOL_VERSION, + compression=False) + + +@requiressimulacron +class DuplicateRpcTest(SimulacronCluster): + connect = False + + def test_duplicate(self): + mock_handler = MockLoggingHandler() + logger = logging.getLogger(cassandra.cluster.__name__) + logger.addHandler(mock_handler) + address_column = "native_transport_address" if DSE_VERSION and DSE_VERSION > Version("6.0") else "rpc_address" + rows = [ + {"peer": "127.0.0.1", "data_center": "dc", "host_id": "dontcare1", "rack": "rack1", + "release_version": "3.11.4", address_column: "127.0.0.1", "schema_version": "dontcare", "tokens": "1"}, + {"peer": "127.0.0.2", "data_center": "dc", "host_id": "dontcare2", "rack": "rack1", + "release_version": "3.11.4", address_column: "127.0.0.2", "schema_version": "dontcare", "tokens": "2"}, + ] + prime_query(ControlConnection._SELECT_PEERS, rows=rows) + + cluster = Cluster(protocol_version=PROTOCOL_VERSION, compression=False) + session = cluster.connect(wait_for_all_pools=True) + + warnings = mock_handler.messages.get("warning") + self.assertEqual(len(warnings), 1) + self.assertTrue('multiple hosts with the same endpoint' in warnings[0]) + 
logger.removeHandler(mock_handler) + cluster.shutdown() diff --git a/tests/integration/simulacron/test_connection.py b/tests/integration/simulacron/test_connection.py new file mode 100644 index 0000000000..de8060da2d --- /dev/null +++ b/tests/integration/simulacron/test_connection.py @@ -0,0 +1,511 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import logging +import time +from unittest.mock import Mock, patch + +from cassandra import OperationTimedOut +from cassandra.cluster import (EXEC_PROFILE_DEFAULT, Cluster, ExecutionProfile, + _Scheduler, NoHostAvailable) +from cassandra.policies import HostStateListener, RoundRobinPolicy, WhiteListRoundRobinPolicy + +from tests import connection_class, thread_pool_executor_class +from tests.util import late +from tests.integration import requiressimulacron, libevtest +from tests.integration.util import assert_quiescent_pool_state +# important to import the patch PROTOCOL_VERSION from the simulacron module +from tests.integration.simulacron import SimulacronBase, PROTOCOL_VERSION +from cassandra.connection import DEFAULT_CQL_VERSION, Connection +from tests.unit.cython.utils import cythontest +from tests.integration.simulacron.utils import (NO_THEN, PrimeOptions, + prime_query, prime_request, + start_and_prime_cluster_defaults, + start_and_prime_singledc, + clear_queries, RejectConnections, + RejectType, AcceptConnections, PauseReads, ResumeReads) + + +class TrackDownListener(HostStateListener): + def __init__(self): + self.hosts_marked_down = [] + + def on_down(self, host): + self.hosts_marked_down.append(host) + + def on_up(self, host): + pass + + def on_add(self, host): + pass + + def on_remove(self, host): + pass + +class ThreadTracker(thread_pool_executor_class): + called_functions = [] + + def submit(self, fn, *args, **kwargs): + self.called_functions.append(fn.__name__) + return super(ThreadTracker, self).submit(fn, *args, **kwargs) + + +class OrderedRoundRobinPolicy(RoundRobinPolicy): + + def make_query_plan(self, working_keyspace=None, query=None): + self._position += 1 + + hosts = [] + for _ in range(10): + hosts.extend(sorted(self._live_hosts, key=lambda x : x.address)) + + return hosts + + +def _send_options_message(self): + """ + Mock that doesn't the OptionMessage. 
It is required for the heart_beat_timeout + test to avoid a condition where the CC tries to reconnect in the executor but can't + since we prime that message.""" + self._compressor = None + self.cql_version = DEFAULT_CQL_VERSION + self._send_startup_message(no_compact=self.no_compact) + + +@requiressimulacron +class ConnectionTests(SimulacronBase): + + @patch('cassandra.connection.Connection._send_options_message', _send_options_message) + def test_heart_beat_timeout(self): + """ + Test to ensure the hosts are marked as down after a OTO is received. + Also to ensure this happens within the expected timeout + @since 3.10 + @jira_ticket PYTHON-762 + @expected_result all the hosts have been marked as down at some point + + @test_category metadata + """ + number_of_dcs = 3 + nodes_per_dc = 20 + + query_to_prime = "INSERT INTO test3rf.test (k, v) VALUES (0, 1);" + + idle_heartbeat_timeout = 5 + idle_heartbeat_interval = 1 + + start_and_prime_cluster_defaults(number_of_dcs, nodes_per_dc) + + listener = TrackDownListener() + executor = ThreadTracker(max_workers=8) + + # We need to disable compression since it's not supported in simulacron + cluster = Cluster(compression=False, + idle_heartbeat_interval=idle_heartbeat_interval, + idle_heartbeat_timeout=idle_heartbeat_timeout, + protocol_version=PROTOCOL_VERSION, + executor_threads=8, + execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile(load_balancing_policy=RoundRobinPolicy())}) + self.addCleanup(cluster.shutdown) + + cluster.scheduler.shutdown() + cluster.executor = executor + cluster.scheduler = _Scheduler(executor) + + session = cluster.connect(wait_for_all_pools=True) + cluster.register_listener(listener) + + log = logging.getLogger() + log.setLevel('CRITICAL') + self.addCleanup(log.setLevel, "DEBUG") + + prime_query(query_to_prime, then=NO_THEN) + + futures = [] + for _ in range(number_of_dcs * nodes_per_dc): + future = session.execute_async(query_to_prime) + futures.append(future) + + for f in futures: + f._event.wait() + self.assertIsInstance(f._final_exception, OperationTimedOut) + + prime_request(PrimeOptions(then=NO_THEN)) + + # We allow from some extra time for all the hosts to be to on_down + # The callbacks should start happening after idle_heartbeat_timeout + idle_heartbeat_interval + time.sleep((idle_heartbeat_timeout + idle_heartbeat_interval) * 2.5) + + for host in cluster.metadata.all_hosts(): + self.assertIn(host, listener.hosts_marked_down) + + # In this case HostConnection._replace shouldn't be called + self.assertNotIn("_replace", executor.called_functions) + + def test_callbacks_and_pool_when_oto(self): + """ + Test to ensure the callbacks are correcltly called and the connection + is returned when there is an OTO + @since 3.12 + @jira_ticket PYTHON-630 + @expected_result the connection is correctly returned to the pool + after an OTO, also the only the errback is called and not the callback + when the message finally arrives. 
+ + @test_category metadata + """ + start_and_prime_singledc() + + cluster = Cluster(protocol_version=PROTOCOL_VERSION, compression=False) + session = cluster.connect() + self.addCleanup(cluster.shutdown) + + query_to_prime = "SELECT * from testkesypace.testtable" + + server_delay = 2 # seconds + prime_query(query_to_prime, then={"delay_in_ms": server_delay * 1000}) + + future = session.execute_async(query_to_prime, timeout=1) + callback, errback = Mock(name='callback'), Mock(name='errback') + future.add_callbacks(callback, errback) + self.assertRaises(OperationTimedOut, future.result) + + assert_quiescent_pool_state(self, cluster) + + time.sleep(server_delay + 1) + # PYTHON-630 -- only the errback should be called + errback.assert_called_once() + callback.assert_not_called() + + @cythontest + @libevtest + def test_heartbeat_defunct_deadlock(self): + """ + Ensure that there is no deadlock when request is in-flight and heartbeat defuncts connection + @since 3.16 + @jira_ticket PYTHON-1044 + @expected_result an OperationTimeout is raised and no deadlock occurs + + @test_category connection + """ + start_and_prime_singledc() + + # This is all about timing. We will need the QUERY response future to time out and the heartbeat to defunct + # at the same moment. The latter will schedule a QUERY retry to another node in case the pool is not + # already shut down. If and only if the response future timeout falls in between the retry scheduling and + # its execution the deadlock occurs. The odds are low, so we need to help fate a bit: + # 1) Make one heartbeat messages be sent to every node + # 2) Our QUERY goes always to the same host + # 3) This host needs to defunct first + # 4) Open a small time window for the response future timeout, i.e. block executor threads for retry + # execution and last connection to defunct + query_to_prime = "SELECT * from testkesypace.testtable" + query_host = "127.0.0.2" + heartbeat_interval = 1 + heartbeat_timeout = 1 + lag = 0.05 + never = 9999 + + class PatchedRoundRobinPolicy(RoundRobinPolicy): + # Send always to same host + def make_query_plan(self, working_keyspace=None, query=None): + if query and query.query_string == query_to_prime: + return filter(lambda h: h == query_host, self._live_hosts) + else: + return super(PatchedRoundRobinPolicy, self).make_query_plan() + + class PatchedCluster(Cluster): + # Make sure that QUERY connection will timeout first + def get_connection_holders(self): + holders = super(PatchedCluster, self).get_connection_holders() + return sorted(holders, reverse=True, key=lambda v: int(v._connection.host == query_host)) + + # Block executor thread like closing a dead socket could do + def connection_factory(self, *args, **kwargs): + conn = super(PatchedCluster, self).connection_factory(*args, **kwargs) + conn.defunct = late(seconds=2*lag)(conn.defunct) + return conn + + cluster = PatchedCluster( + protocol_version=PROTOCOL_VERSION, + compression=False, + idle_heartbeat_interval=heartbeat_interval, + idle_heartbeat_timeout=heartbeat_timeout, + load_balancing_policy=PatchedRoundRobinPolicy() + ) + session = cluster.connect() + self.addCleanup(cluster.shutdown) + + prime_query(query_to_prime, then={"delay_in_ms": never}) + + # Make heartbeat due + time.sleep(heartbeat_interval) + + future = session.execute_async(query_to_prime, timeout=heartbeat_interval+heartbeat_timeout+3*lag) + # Delay thread execution like kernel could do + future._retry_task = late(seconds=4*lag)(future._retry_task) + + prime_request(PrimeOptions(then={"result": 
"no_result", "delay_in_ms": never})) + prime_request(RejectConnections("unbind")) + + self.assertRaisesRegex(OperationTimedOut, "Connection defunct by heartbeat", future.result) + + def test_close_when_query(self): + """ + Test to ensure the driver behaves correctly if the connection is closed + just when querying + @since 3.12 + @expected_result NoHostAvailable is risen + + @test_category connection + """ + start_and_prime_singledc() + + cluster = Cluster(protocol_version=PROTOCOL_VERSION, compression=False) + session = cluster.connect() + self.addCleanup(cluster.shutdown) + + query_to_prime = "SELECT * from testkesypace.testtable" + + for close_type in ("disconnect", "shutdown_read", "shutdown_write"): + then = { + "result": "close_connection", + "delay_in_ms": 0, + "close_type": close_type, + "scope": "connection" + } + + prime_query(query_to_prime, rows=None, column_types=None, then=then) + self.assertRaises(NoHostAvailable, session.execute, query_to_prime) + + def test_retry_after_defunct(self): + """ + We test cluster._retry is called if an the connection is defunct + in the middle of a query + + Finally we verify the driver recovers correctly in the event + of a network partition + + @since 3.12 + @expected_result the driver is able to query even if a host is marked + as down in the middle of the query, it will go to the next one if the timeout + hasn't expired + + @test_category connection + """ + number_of_dcs = 3 + nodes_per_dc = 2 + + query_to_prime = "INSERT INTO test3rf.test (k, v) VALUES (0, 1);" + + idle_heartbeat_timeout = 1 + idle_heartbeat_interval = 5 + + simulacron_cluster = start_and_prime_cluster_defaults(number_of_dcs, nodes_per_dc) + + dc_ids = sorted(simulacron_cluster.data_center_ids) + last_host = dc_ids.pop() + prime_query(query_to_prime, + cluster_name="{}/{}".format(simulacron_cluster.cluster_name, last_host)) + + roundrobin_lbp = OrderedRoundRobinPolicy() + cluster = Cluster(protocol_version=PROTOCOL_VERSION, compression=False, + idle_heartbeat_interval=idle_heartbeat_interval, + idle_heartbeat_timeout=idle_heartbeat_timeout, + execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile(load_balancing_policy=roundrobin_lbp)}) + + session = cluster.connect(wait_for_all_pools=True) + self.addCleanup(cluster.shutdown) + + # This simulates we only have access to one DC + for dc_id in dc_ids: + datacenter_path = "{}/{}".format(simulacron_cluster.cluster_name, dc_id) + prime_query(query_to_prime, then=NO_THEN, cluster_name=datacenter_path) + prime_request(PrimeOptions(then=NO_THEN, cluster_name=datacenter_path)) + + # Only the last datacenter will respond, therefore the first host won't + # We want to make sure the returned hosts are 127.0.0.1, 127.0.0.2, ... 
127.0.0.8 + roundrobin_lbp._position = 0 + + # After 3 + 1 seconds the connection should be marked and down and another host retried + response_future = session.execute_async(query_to_prime, timeout=4 * idle_heartbeat_interval + + idle_heartbeat_timeout) + response_future.result() + self.assertGreater(len(response_future.attempted_hosts), 1) + + # No error should be raised here since the hosts have been marked + # as down and there's still 1 DC available + for _ in range(10): + session.execute(query_to_prime) + + # Might take some time to close the previous connections and reconnect + time.sleep(10) + assert_quiescent_pool_state(self, cluster) + clear_queries() + + time.sleep(10) + assert_quiescent_pool_state(self, cluster) + + def test_idle_connection_is_not_closed(self): + """ + Test to ensure that the connections aren't closed if they are idle + @since 3.12 + @jira_ticket PYTHON-573 + @expected_result the connections aren't closed nor the hosts are + set to down if the connection is idle + + @test_category connection + """ + start_and_prime_singledc() + + idle_heartbeat_timeout = 1 + idle_heartbeat_interval = 1 + + listener = TrackDownListener() + cluster = Cluster(protocol_version=PROTOCOL_VERSION, compression=False, + idle_heartbeat_interval=idle_heartbeat_interval, + idle_heartbeat_timeout=idle_heartbeat_timeout) + session = cluster.connect(wait_for_all_pools=True) + cluster.register_listener(listener) + + self.addCleanup(cluster.shutdown) + + time.sleep(20) + + self.assertEqual(listener.hosts_marked_down, []) + + def test_host_is_not_set_to_down_after_query_oto(self): + """ + Test to ensure that the connections aren't closed if there's an + OperationTimedOut in a normal query. This should only happen from the + heart beat thread (in the case of a OperationTimedOut) with the default + configuration + @since 3.12 + @expected_result the connections aren't closed nor the hosts are + set to down + + @test_category connection + """ + start_and_prime_singledc() + + query_to_prime = "SELECT * FROM madeup_keyspace.madeup_table" + + prime_query(query_to_prime, then=NO_THEN) + + listener = TrackDownListener() + cluster = Cluster(protocol_version=PROTOCOL_VERSION, compression=False) + session = cluster.connect(wait_for_all_pools=True) + cluster.register_listener(listener) + + futures = [] + for _ in range(10): + future = session.execute_async(query_to_prime) + futures.append(future) + + for f in futures: + f._event.wait() + self.assertIsInstance(f._final_exception, OperationTimedOut) + + self.assertEqual(listener.hosts_marked_down, []) + assert_quiescent_pool_state(self, cluster) + + def test_can_shutdown_connection_subclass(self): + start_and_prime_singledc() + class ExtendedConnection(connection_class): + pass + + cluster = Cluster(protocol_version=PROTOCOL_VERSION, + contact_points=["127.0.0.2"], + connection_class=ExtendedConnection, + compression=False) + cluster.connect() + cluster.shutdown() + + def test_driver_recovers_nework_isolation(self): + start_and_prime_singledc() + + idle_heartbeat_timeout = 3 + idle_heartbeat_interval = 1 + + listener = TrackDownListener() + + cluster = Cluster(protocol_version=PROTOCOL_VERSION, contact_points=['127.0.0.1'], + idle_heartbeat_timeout=idle_heartbeat_timeout, + idle_heartbeat_interval=idle_heartbeat_interval, + executor_threads=16, + compression=False, + execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile(load_balancing_policy=RoundRobinPolicy()) + }) + session = cluster.connect(wait_for_all_pools=True) + + 
cluster.register_listener(listener) + + prime_request(PrimeOptions(then=NO_THEN)) + prime_request(RejectConnections(RejectType.REJECT_STARTUP)) + + time.sleep((idle_heartbeat_timeout + idle_heartbeat_interval) * 2) + + for host in cluster.metadata.all_hosts(): + self.assertIn(host, listener.hosts_marked_down) + + self.assertRaises(NoHostAvailable, session.execute, "SELECT * from system.local") + + clear_queries() + prime_request(AcceptConnections()) + + time.sleep(idle_heartbeat_timeout + idle_heartbeat_interval + 2) + + self.assertIsNotNone(session.execute("SELECT * from system.local")) + + def test_max_in_flight(self): + """ Verify we don't exceed max_in_flight when borrowing connections or sending heartbeats """ + Connection.max_in_flight = 50 + start_and_prime_singledc() + profile = ExecutionProfile(request_timeout=1, load_balancing_policy=WhiteListRoundRobinPolicy(['127.0.0.1'])) + cluster = Cluster( + protocol_version=PROTOCOL_VERSION, + compression=False, + execution_profiles={EXEC_PROFILE_DEFAULT: profile}, + idle_heartbeat_interval=.1, + idle_heartbeat_timeout=.1, + ) + session = cluster.connect(wait_for_all_pools=True) + self.addCleanup(cluster.shutdown) + + query = session.prepare("INSERT INTO table1 (id) VALUES (?)") + + prime_request(PauseReads()) + + futures = [] + # + 50 because simulacron doesn't immediately block all queries + for i in range(Connection.max_in_flight + 50): + futures.append(session.execute_async(query, ['a'])) + + prime_request(ResumeReads()) + + for future in futures: + # We're veryfing we don't get an assertion error from Connection.get_request_id, + # so skip any valid errors + try: + future.result() + except OperationTimedOut: + pass + except NoHostAvailable: + pass diff --git a/tests/integration/simulacron/test_empty_column.py b/tests/integration/simulacron/test_empty_column.py new file mode 100644 index 0000000000..38d4c0f2a9 --- /dev/null +++ b/tests/integration/simulacron/test_empty_column.py @@ -0,0 +1,257 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
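# Illustrative sketch (not part of the patch): switching the row factory on the
# default execution profile, as the empty-column tests below do for each factory in
# turn.  Shown with dict_factory only.
from cassandra.cluster import EXEC_PROFILE_DEFAULT
from cassandra.query import dict_factory

def rows_as_dicts(cluster, session, query):
    cluster.profile_manager.profiles[EXEC_PROFILE_DEFAULT].row_factory = dict_factory
    return list(session.execute(query))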
+import unittest + +from collections import namedtuple, OrderedDict + +from cassandra import ProtocolVersion +from cassandra.cluster import Cluster, EXEC_PROFILE_DEFAULT +from cassandra.query import (named_tuple_factory, tuple_factory, + dict_factory, ordered_dict_factory) + +from cassandra.cqlengine import columns +from cassandra.cqlengine.connection import set_session +from cassandra.cqlengine.models import Model + +from tests.integration import requiressimulacron +from tests.integration.simulacron import PROTOCOL_VERSION, SimulacronCluster +from tests.integration.simulacron.utils import PrimeQuery, prime_request + + +PROTOCOL_VERSION = 4 if PROTOCOL_VERSION in \ + (ProtocolVersion.DSE_V1, ProtocolVersion.DSE_V2) else PROTOCOL_VERSION + + +@requiressimulacron +class EmptyColumnTests(SimulacronCluster): + """ + Test that legacy empty column names can be read by the driver. + + @since 3.18 + @jira_ticket PYTHON-1082 + @expected_result the driver supports those columns + """ + connect = False + + def tearDown(self): + if self.cluster: + self.cluster.shutdown() + + @staticmethod + def _prime_testtable_query(): + queries = [ + 'SELECT "", " " FROM testks.testtable', + 'SELECT "", " " FROM testks.testtable LIMIT 10000' # cqlengine + ] + then = { + 'result': 'success', + 'delay_in_ms': 0, + 'rows': [ + { + "": "testval", + " ": "testval1" + } + ], + 'column_types': { + "": "ascii", + " ": "ascii" + }, + 'ignore_on_prepare': False + } + for query in queries: + prime_request(PrimeQuery(query, then=then)) + + def test_empty_columns_with_all_row_factories(self): + query = 'SELECT "", " " FROM testks.testtable' + self._prime_testtable_query() + + self.cluster = Cluster(protocol_version=PROTOCOL_VERSION, compression=False) + self.session = self.cluster.connect(wait_for_all_pools=True) + + # Test all row factories + self.cluster.profile_manager.profiles[EXEC_PROFILE_DEFAULT].row_factory = named_tuple_factory + self.assertEqual( + list(self.session.execute(query)), + [namedtuple('Row', ['field_0_', 'field_1_'])('testval', 'testval1')] + ) + + self.cluster.profile_manager.profiles[EXEC_PROFILE_DEFAULT].row_factory = tuple_factory + self.assertEqual( + list(self.session.execute(query)), + [('testval', 'testval1')] + ) + + self.cluster.profile_manager.profiles[EXEC_PROFILE_DEFAULT].row_factory = dict_factory + self.assertEqual( + list(self.session.execute(query)), + [{'': 'testval', ' ': 'testval1'}] + ) + + self.cluster.profile_manager.profiles[EXEC_PROFILE_DEFAULT].row_factory = ordered_dict_factory + self.assertEqual( + list(self.session.execute(query)), + [OrderedDict((('', 'testval'), (' ', 'testval1')))] + ) + + def test_empty_columns_in_system_schema(self): + queries = [ + "SELECT * FROM system_schema.tables", + "SELECT * FROM system.schema.tables", + "SELECT * FROM system.schema_columnfamilies" + ] + then = { + 'result': 'success', + 'delay_in_ms': 0, + 'rows': [ + { + "compression": dict(), + "compaction": dict(), + "bloom_filter_fp_chance": 0.1, + "caching": {"keys": "ALL", "rows_per_partition": "NONE"}, + "comment": "comment", + "gc_grace_seconds": 60000, + "keyspace_name": "testks", + "table_name": "testtable", + "columnfamily_name": "testtable", # C* 2.2 + "flags": ["compound"], + "comparator": "none" # C* 2.2 + } + ], + 'column_types': { + "compression": "map", + "compaction": "map", + "bloom_filter_fp_chance": "double", + "caching": "map", + "comment": "ascii", + "gc_grace_seconds": "int", + "keyspace_name": "ascii", + "table_name": "ascii", + "columnfamily_name": "ascii", + "flags": "set", 
+ "comparator": "ascii" + }, + 'ignore_on_prepare': False + } + for query in queries: + query = PrimeQuery(query, then=then) + prime_request(query) + + queries = [ + "SELECT * FROM system_schema.keyspaces", + "SELECT * FROM system.schema_keyspaces" + ] + then = { + 'result': 'success', + 'delay_in_ms': 0, + 'rows': [ + { + "strategy_class": "SimpleStrategy", # C* 2.2 + "strategy_options": '{}', # C* 2.2 + "replication": {'strategy': 'SimpleStrategy', 'replication_factor': 1}, + "durable_writes": True, + "keyspace_name": "testks" + } + ], + 'column_types': { + "strategy_class": "ascii", + "strategy_options": "ascii", + "replication": "map", + "keyspace_name": "ascii", + "durable_writes": "boolean" + }, + 'ignore_on_prepare': False + } + for query in queries: + query = PrimeQuery(query, then=then) + prime_request(query) + + queries = [ + "SELECT * FROM system_schema.columns", + "SELECT * FROM system.schema.columns", + "SELECT * FROM system.schema_columns" + ] + then = { + 'result': 'success', + 'delay_in_ms': 0, + 'rows': [ + { + "table_name": 'testtable', + "columnfamily_name": 'testtable', # C* 2.2 + "column_name": "", + "keyspace_name": "testks", + "kind": "partition_key", + "clustering_order": "none", + "position": 0, + "type": "text", + "column_name_bytes": 0x12, + "validator": "none" # C* 2.2 + }, + { + "table_name": 'testtable', + "columnfamily_name": 'testtable', # C* 2.2 + "column_name": " ", + "keyspace_name": "testks", + "kind": "regular", + "clustering_order": "none", + "position": -1, + "type": "text", + "column_name_bytes": 0x13, + "validator": "none" # C* 2.2 + } + ], + 'column_types': { + "table_name": "ascii", + "columnfamily_name": "ascii", + "column_name": "ascii", + "keyspace_name": "ascii", + "clustering_order": "ascii", + "column_name_bytes": "blob", + "kind": "ascii", + "position": "int", + "type": "ascii", + "validator": "ascii" # C* 2.2 + }, + 'ignore_on_prepare': False + } + for query in queries: + query = PrimeQuery(query, then=then) + prime_request(query) + + self.cluster = Cluster(protocol_version=PROTOCOL_VERSION, compression=False) + self.session = self.cluster.connect(wait_for_all_pools=True) + + table_metadata = self.cluster.metadata.keyspaces['testks'].tables['testtable'] + self.assertEqual(len(table_metadata.columns), 2) + self.assertIn('', table_metadata.columns) + self.assertIn(' ', table_metadata.columns) + + def test_empty_columns_with_cqlengine(self): + self._prime_testtable_query() + + self.cluster = Cluster(protocol_version=PROTOCOL_VERSION, compression=False) + self.session = self.cluster.connect(wait_for_all_pools=True) + set_session(self.session) + + class TestModel(Model): + __keyspace__ = 'testks' + __table_name__ = 'testtable' + empty = columns.Text(db_field='', primary_key=True) + space = columns.Text(db_field=' ') + + self.assertEqual( + [TestModel(empty='testval', space='testval1')], + list(TestModel.objects.only(['empty', 'space']).all()) + ) diff --git a/tests/integration/simulacron/test_endpoint.py b/tests/integration/simulacron/test_endpoint.py new file mode 100644 index 0000000000..6ab190091d --- /dev/null +++ b/tests/integration/simulacron/test_endpoint.py @@ -0,0 +1,116 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +from functools import total_ordering + +from cassandra.cluster import Cluster +from cassandra.connection import DefaultEndPoint, EndPoint, DefaultEndPointFactory +from cassandra.metadata import _NodeInfo +from tests.integration import requiressimulacron +from tests.integration.simulacron import SimulacronCluster, PROTOCOL_VERSION + + +@total_ordering +class AddressEndPoint(EndPoint): + + def __init__(self, address, port=9042): + self._address = address + self._port = port + + @property + def address(self): + return self._address + + @property + def port(self): + return self._port + + def resolve(self): + return self._address, self._port # connection purpose + + def __eq__(self, other): + return isinstance(other, AddressEndPoint) and \ + self.address == other.address + + def __hash__(self): + return hash(self.address) + + def __lt__(self, other): + return self.address < other.address + + def __str__(self): + return str("%s" % self.address) + + def __repr__(self): + return "<%s: %s>" % (self.__class__.__name__, self.address) + + +class AddressEndPointFactory(DefaultEndPointFactory): + + def create(self, row): + addr = _NodeInfo.get_broadcast_rpc_address(row) + return AddressEndPoint(addr) + + +@requiressimulacron +class EndPointTests(SimulacronCluster): + """ + Basic tests to validate the internal use of the EndPoint class. 
+ + @since 3.18 + @jira_ticket PYTHON-1079 + @expected_result all the hosts are using the proper endpoint class + """ + + def test_default_endpoint(self): + hosts = self.cluster.metadata.all_hosts() + self.assertEqual(len(hosts), 3) + for host in hosts: + self.assertIsNotNone(host.endpoint) + self.assertIsInstance(host.endpoint, DefaultEndPoint) + self.assertEqual(host.address, host.endpoint.address) + self.assertEqual(host.broadcast_rpc_address, host.endpoint.address) + + self.assertIsInstance(self.cluster.control_connection._connection.endpoint, DefaultEndPoint) + self.assertIsNotNone(self.cluster.control_connection._connection.endpoint) + endpoints = [host.endpoint for host in hosts] + self.assertIn(self.cluster.control_connection._connection.endpoint, endpoints) + + def test_custom_endpoint(self): + cluster = Cluster( + contact_points=[AddressEndPoint('127.0.0.1')], + protocol_version=PROTOCOL_VERSION, + endpoint_factory=AddressEndPointFactory(), + compression=False, + ) + cluster.connect(wait_for_all_pools=True) + + hosts = cluster.metadata.all_hosts() + self.assertEqual(len(hosts), 3) + for host in hosts: + self.assertIsNotNone(host.endpoint) + self.assertIsInstance(host.endpoint, AddressEndPoint) + self.assertEqual(str(host.endpoint), host.endpoint.address) + self.assertEqual(host.address, host.endpoint.address) + self.assertEqual(host.broadcast_rpc_address, host.endpoint.address) + + self.assertIsInstance(cluster.control_connection._connection.endpoint, AddressEndPoint) + self.assertIsNotNone(cluster.control_connection._connection.endpoint) + endpoints = [host.endpoint for host in hosts] + self.assertIn(cluster.control_connection._connection.endpoint, endpoints) + + cluster.shutdown() diff --git a/tests/integration/simulacron/test_policies.py b/tests/integration/simulacron/test_policies.py new file mode 100644 index 0000000000..a41fd54c59 --- /dev/null +++ b/tests/integration/simulacron/test_policies.py @@ -0,0 +1,463 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
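+
+# The tests in this module exercise speculative execution and retry policies against a
+# primed Simulacron cluster rather than a real Cassandra node. As a brief orientation
+# sketch (the call mirrors the first speculative execution test below), a query is
+# primed with an artificial delay so the driver is forced to speculate:
+#
+#     prime_query("INSERT INTO test3rf.test (k, v) VALUES (0, 1);",
+#                 then={"delay_in_ms": 10000})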
+import unittest + +from cassandra import OperationTimedOut, WriteTimeout +from cassandra.cluster import Cluster, ExecutionProfile, ResponseFuture, EXEC_PROFILE_DEFAULT, NoHostAvailable +from cassandra.query import SimpleStatement +from cassandra.policies import ConstantSpeculativeExecutionPolicy, RoundRobinPolicy, RetryPolicy, WriteType +from cassandra.protocol import OverloadedErrorMessage, IsBootstrappingErrorMessage, TruncateError, ServerError + +from tests.integration import greaterthancass21, requiressimulacron, SIMULACRON_JAR, \ + CASSANDRA_VERSION +from tests.integration.simulacron import PROTOCOL_VERSION +from tests.integration.simulacron.utils import start_and_prime_singledc, prime_query, \ + stop_simulacron, NO_THEN, clear_queries + +from itertools import count +from packaging.version import Version + + +class BadRoundRobinPolicy(RoundRobinPolicy): + def make_query_plan(self, working_keyspace=None, query=None): + pos = self._position + self._position += 1 + + hosts = [] + for _ in range(10): + hosts.extend(self._live_hosts) + + return hosts + + +# This doesn't work well with Windows clock granularity +@requiressimulacron +class SpecExecTest(unittest.TestCase): + + @classmethod + def setUpClass(cls): + if SIMULACRON_JAR is None or CASSANDRA_VERSION < Version("2.1"): + return + + start_and_prime_singledc() + cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION, compression=False) + cls.session = cls.cluster.connect(wait_for_all_pools=True) + + spec_ep_brr = ExecutionProfile(load_balancing_policy=BadRoundRobinPolicy(), + speculative_execution_policy=ConstantSpeculativeExecutionPolicy(1, 6), + request_timeout=12) + spec_ep_rr = ExecutionProfile(speculative_execution_policy=ConstantSpeculativeExecutionPolicy(.5, 10), + request_timeout=12) + spec_ep_rr_lim = ExecutionProfile(load_balancing_policy=BadRoundRobinPolicy(), + speculative_execution_policy=ConstantSpeculativeExecutionPolicy(0.5, 1), + request_timeout=12) + spec_ep_brr_lim = ExecutionProfile(load_balancing_policy=BadRoundRobinPolicy(), + speculative_execution_policy=ConstantSpeculativeExecutionPolicy(4, 10)) + + cls.cluster.add_execution_profile("spec_ep_brr", spec_ep_brr) + cls.cluster.add_execution_profile("spec_ep_rr", spec_ep_rr) + cls.cluster.add_execution_profile("spec_ep_rr_lim", spec_ep_rr_lim) + cls.cluster.add_execution_profile("spec_ep_brr_lim", spec_ep_brr_lim) + + @classmethod + def tearDownClass(cls): + if SIMULACRON_JAR is None or CASSANDRA_VERSION < Version("2.1"): + return + + cls.cluster.shutdown() + stop_simulacron() + + def tearDown(self): + clear_queries() + + @greaterthancass21 + def test_speculative_execution(self): + """ + Test to ensure that speculative execution honors LBP, and that they retry appropriately. + + This test will use various LBP, and ConstantSpeculativeExecutionPolicy settings and ensure the proper number of hosts are queried + @since 3.7.0 + @jira_ticket PYTHON-218 + @expected_result speculative retries should honor max retries, idempotent state of queries, and underlying lbp. 
+
+        @test_category metadata
+        """
+        query_to_prime = "INSERT INTO test3rf.test (k, v) VALUES (0, 1);"
+        prime_query(query_to_prime, then={"delay_in_ms": 10000})
+
+        statement = SimpleStatement(query_to_prime, is_idempotent=True)
+        statement_non_idem = SimpleStatement(query_to_prime, is_idempotent=False)
+
+        # This LBP should repeat hosts up to around 30
+        result = self.session.execute(statement, execution_profile='spec_ep_brr')
+        self.assertEqual(7, len(result.response_future.attempted_hosts))
+
+        # This LBP should keep the host list to 3
+        result = self.session.execute(statement, execution_profile='spec_ep_rr')
+        self.assertEqual(3, len(result.response_future.attempted_hosts))
+        # The speculative execution policy should limit retries to 1
+        result = self.session.execute(statement, execution_profile='spec_ep_rr_lim')
+
+        self.assertEqual(2, len(result.response_future.attempted_hosts))
+
+        # The speculative execution policy should not be used if the query is not idempotent
+        result = self.session.execute(statement_non_idem, execution_profile='spec_ep_brr')
+        self.assertEqual(1, len(result.response_future.attempted_hosts))
+
+        # Default profile with a non-idempotent query
+        result = self.session.execute(statement_non_idem, timeout=12)
+        self.assertEqual(1, len(result.response_future.attempted_hosts))
+
+        # Should be able to run an idempotent query against the default execution profile with no speculative_execution_policy
+        result = self.session.execute(statement, timeout=12)
+        self.assertEqual(1, len(result.response_future.attempted_hosts))
+
+        # Test timeout with speculative execution
+        with self.assertRaises(OperationTimedOut):
+            self.session.execute(statement, execution_profile='spec_ep_rr', timeout=.5)
+
+        prepared_query_to_prime = "SELECT * FROM test3rf.test where k = ?"
+        when = {"params": {"k": "0"}, "param_types": {"k": "ascii"}}
+        prime_query(prepared_query_to_prime, when=when, then={"delay_in_ms": 4000})
+
+        # PYTHON-736 Test that the speculative execution policy works with a prepared statement
+        prepared_statement = self.session.prepare(prepared_query_to_prime)
+        # non-idempotent
+        result = self.session.execute(prepared_statement, ("0",), execution_profile='spec_ep_brr')
+        self.assertEqual(1, len(result.response_future.attempted_hosts))
+        # idempotent
+        prepared_statement.is_idempotent = True
+        result = self.session.execute(prepared_statement, ("0",), execution_profile='spec_ep_brr')
+        self.assertLess(1, len(result.response_future.attempted_hosts))
+
+    def test_speculative_and_timeout(self):
+        """
+        Test to ensure the timeout is honored when using speculative execution
+        @since 3.10
+        @jira_ticket PYTHON-750
+        @expected_result speculative retries are scheduled at a fixed interval, up to the maximum
+        period allowed by the timeout.
+
+        @test_category metadata
+        """
+        query_to_prime = "INSERT INTO testkeyspace.testtable (k, v) VALUES (0, 1);"
+        prime_query(query_to_prime, then=NO_THEN)
+
+        statement = SimpleStatement(query_to_prime, is_idempotent=True)
+
+        # An OperationTimedOut is set on the response_future here; that's why we
+        # can't call session.execute (which would raise it) and instead wait
+        # directly on the underlying event.
+        response_future = self.session.execute_async(statement, execution_profile='spec_ep_brr_lim',
+                                                     timeout=14)
+        response_future._event.wait(16)
+        self.assertIsInstance(response_future._final_exception, OperationTimedOut)
+
+        # With a 14s timeout and a 4s constant speculative delay, 14 / 4 + 1 = 4 requests are attempted
+        self.assertEqual(len(response_future.attempted_hosts), 4)
+
+    def test_delay_can_be_0(self):
+        """
+        Test to validate that the delay can be zero for the ConstantSpeculativeExecutionPolicy
+        @since 3.13
+        @jira_ticket PYTHON-836
+        @expected_result all the queries are executed immediately
+        @test_category policy
+        """
+        query_to_prime = "INSERT INTO madeup_keyspace.madeup_table(k, v) VALUES (1, 2)"
+        prime_query(query_to_prime, then={"delay_in_ms": 5000})
+        number_of_requests = 4
+        spec = ExecutionProfile(load_balancing_policy=RoundRobinPolicy(),
+                                speculative_execution_policy=ConstantSpeculativeExecutionPolicy(0, number_of_requests))
+
+        cluster = Cluster(protocol_version=PROTOCOL_VERSION, compression=False)
+        cluster.add_execution_profile("spec", spec)
+        session = cluster.connect(wait_for_all_pools=True)
+        self.addCleanup(cluster.shutdown)
+
+        counter = count()
+
+        def patch_and_count(f):
+            def patched(*args, **kwargs):
+                next(counter)
+                f(*args, **kwargs)
+            return patched
+
+        self.addCleanup(setattr, ResponseFuture, "send_request", ResponseFuture.send_request)
+        ResponseFuture.send_request = patch_and_count(ResponseFuture.send_request)
+        stmt = SimpleStatement(query_to_prime)
+        stmt.is_idempotent = True
+        results = session.execute(stmt, execution_profile="spec")
+        self.assertEqual(len(results.response_future.attempted_hosts), 3)
+
+        # send_request is called number_of_requests times for the speculative requests
+        # plus one for the call from the main thread.
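+        # (with number_of_requests == 4 above, next(counter) is therefore expected to return 5)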
+ self.assertEqual(next(counter), number_of_requests + 1) + + +class CustomRetryPolicy(RetryPolicy): + def on_write_timeout(self, query, consistency, write_type, + required_responses, received_responses, retry_num): + if retry_num != 0: + return self.RETHROW, None + elif write_type == WriteType.SIMPLE: + return self.RETHROW, None + elif write_type == WriteType.CDC: + return self.IGNORE, None + + +class CounterRetryPolicy(RetryPolicy): + def __init__(self): + self.write_timeout = count() + self.read_timeout = count() + self.unavailable = count() + self.request_error = count() + + def on_read_timeout(self, query, consistency, required_responses, + received_responses, data_retrieved, retry_num): + next(self.read_timeout) + return self.IGNORE, None + + def on_write_timeout(self, query, consistency, write_type, + required_responses, received_responses, retry_num): + next(self.write_timeout) + return self.IGNORE, None + + def on_unavailable(self, query, consistency, required_replicas, alive_replicas, retry_num): + next(self.unavailable) + return self.IGNORE, None + + def on_request_error(self, query, consistency, error, retry_num): + next(self.request_error) + return self.RETHROW, None + + def reset_counters(self): + self.write_timeout = count() + self.read_timeout = count() + self.unavailable = count() + self.request_error = count() + + +@requiressimulacron +class RetryPolicyTests(unittest.TestCase): + @classmethod + def setUpClass(cls): + if SIMULACRON_JAR is None or CASSANDRA_VERSION < Version("2.1"): + return + start_and_prime_singledc() + + @classmethod + def tearDownClass(cls): + if SIMULACRON_JAR is None or CASSANDRA_VERSION < Version("2.1"): + return + stop_simulacron() + + def tearDown(self): + clear_queries() + + def set_cluster(self, retry_policy): + self.cluster = Cluster( + protocol_version=PROTOCOL_VERSION, + compression=False, + execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile(retry_policy=retry_policy) + }, + ) + self.session = self.cluster.connect(wait_for_all_pools=True) + self.addCleanup(self.cluster.shutdown) + + def test_retry_policy_ignores_and_rethrows(self): + """ + Test to verify :class:`~cassandra.protocol.WriteTimeoutErrorMessage` is decoded correctly and that + :attr:`.~cassandra.policies.RetryPolicy.RETHROW` and + :attr:`.~cassandra.policies.RetryPolicy.IGNORE` are respected + to localhost + + @since 3.12 + @jira_ticket PYTHON-812 + @expected_result the retry policy functions as expected + + @test_category connection + """ + self.set_cluster(CustomRetryPolicy()) + query_to_prime_simple = "SELECT * from simulacron_keyspace.simple" + query_to_prime_cdc = "SELECT * from simulacron_keyspace.cdc" + then = { + "result": "write_timeout", + "delay_in_ms": 0, + "consistency_level": "LOCAL_QUORUM", + "received": 1, + "block_for": 2, + "write_type": "SIMPLE", + "ignore_on_prepare": True + } + prime_query(query_to_prime_simple, rows=None, column_types=None, then=then) + then["write_type"] = "CDC" + prime_query(query_to_prime_cdc, rows=None, column_types=None, then=then) + + with self.assertRaises(WriteTimeout): + self.session.execute(query_to_prime_simple) + + #CDC should be ignored + self.session.execute(query_to_prime_cdc) + + def test_retry_policy_with_prepared(self): + """ + Test to verify that the retry policy is called as expected + for bound and prepared statements when set at the cluster level + + @since 3.13 + @jira_ticket PYTHON-861 + @expected_result the appropriate retry policy is called + + @test_category connection + """ + counter_policy = 
CounterRetryPolicy() + self.set_cluster(counter_policy) + query_to_prime = "SELECT * from simulacron_keyspace.simulacron_table" + then = { + "result": "write_timeout", + "delay_in_ms": 0, + "consistency_level": "LOCAL_QUORUM", + "received": 1, + "block_for": 2, + "write_type": "SIMPLE", + "ignore_on_prepare": True + } + prime_query(query_to_prime, then=then, rows=None, column_types=None) + self.session.execute(query_to_prime) + self.assertEqual(next(counter_policy.write_timeout), 1) + counter_policy.reset_counters() + + query_to_prime_prepared = "SELECT * from simulacron_keyspace.simulacron_table WHERE key = :key" + when = {"params": {"key": "0"}, "param_types": {"key": "ascii"}} + + prime_query(query_to_prime_prepared, when=when, then=then, rows=None, column_types=None) + + prepared_stmt = self.session.prepare(query_to_prime_prepared) + + bound_stm = prepared_stmt.bind({"key": "0"}) + self.session.execute(bound_stm) + self.assertEqual(next(counter_policy.write_timeout), 1) + + counter_policy.reset_counters() + self.session.execute(prepared_stmt, ("0",)) + self.assertEqual(next(counter_policy.write_timeout), 1) + + def test_setting_retry_policy_to_statement(self): + """ + Test to verify that the retry policy is called as expected + for bound and prepared statements when set to the prepared statement + + @since 3.13 + @jira_ticket PYTHON-861 + @expected_result the appropriate retry policy is called + + @test_category connection + """ + retry_policy = RetryPolicy() + self.set_cluster(retry_policy) + then = { + "result": "write_timeout", + "delay_in_ms": 0, + "consistency_level": "LOCAL_QUORUM", + "received": 1, + "block_for": 2, + "write_type": "SIMPLE", + "ignore_on_prepare": True + } + query_to_prime_prepared = "SELECT * from simulacron_keyspace.simulacron_table WHERE key = :key" + when = {"params": {"key": "0"}, "param_types": {"key": "ascii"}} + prime_query(query_to_prime_prepared, when=when, then=then, rows=None, column_types=None) + + counter_policy = CounterRetryPolicy() + prepared_stmt = self.session.prepare(query_to_prime_prepared) + prepared_stmt.retry_policy = counter_policy + self.session.execute(prepared_stmt, ("0",)) + self.assertEqual(next(counter_policy.write_timeout), 1) + + counter_policy.reset_counters() + bound_stmt = prepared_stmt.bind({"key": "0"}) + bound_stmt.retry_policy = counter_policy + self.session.execute(bound_stmt) + self.assertEqual(next(counter_policy.write_timeout), 1) + + def test_retry_policy_on_request_error(self): + """ + Test to verify that on_request_error is called properly. 
+ + @since 3.18 + @jira_ticket PYTHON-1064 + @expected_result the appropriate retry policy is called + + @test_category connection + """ + overloaded_error = { + "result": "overloaded", + "message": "overloaded" + } + + bootstrapping_error = { + "result": "is_bootstrapping", + "message": "isbootstrapping" + } + + truncate_error = { + "result": "truncate_error", + "message": "truncate_error" + } + + server_error = { + "result": "server_error", + "message": "server_error" + } + + # Test the on_request_error call + retry_policy = CounterRetryPolicy() + self.set_cluster(retry_policy) + + for prime_error, exc in [ + (overloaded_error, OverloadedErrorMessage), + (bootstrapping_error, IsBootstrappingErrorMessage), + (truncate_error, TruncateError), + (server_error, ServerError)]: + + clear_queries() + query_to_prime = "SELECT * from simulacron_keyspace.simulacron_table;" + prime_query(query_to_prime, then=prime_error, rows=None, column_types=None) + rf = self.session.execute_async(query_to_prime) + + with self.assertRaises(exc): + rf.result() + + self.assertEqual(len(rf.attempted_hosts), 1) # no retry + + self.assertEqual(next(retry_policy.request_error), 4) + + # Test that by default, retry on next host + retry_policy = RetryPolicy() + self.set_cluster(retry_policy) + + for e in [overloaded_error, bootstrapping_error, truncate_error, server_error]: + clear_queries() + query_to_prime = "SELECT * from simulacron_keyspace.simulacron_table;" + prime_query(query_to_prime, then=e, rows=None, column_types=None) + rf = self.session.execute_async(query_to_prime) + + with self.assertRaises(NoHostAvailable): + rf.result() + + self.assertEqual(len(rf.attempted_hosts), 3) # all 3 nodes failed diff --git a/tests/integration/simulacron/utils.py b/tests/integration/simulacron/utils.py new file mode 100644 index 0000000000..01d94fc539 --- /dev/null +++ b/tests/integration/simulacron/utils.py @@ -0,0 +1,466 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License + +import json +import subprocess +import time +from urllib.request import build_opener, Request, HTTPHandler + +from cassandra.metadata import SchemaParserV4, SchemaParserDSE68 + +from tests.util import wait_until_not_raised +from tests.integration import CASSANDRA_VERSION, SIMULACRON_JAR, DSE_VERSION + +DEFAULT_CLUSTER = "python_simulacron_cluster" + + +class SimulacronCluster(object): + """ + Represents a Cluster object as returned by simulacron + """ + def __init__(self, json_text): + self.json = json_text + self.o = json.loads(json_text) + + @property + def cluster_id(self): + return self.o["id"] + + @property + def cluster_name(self): + return self.o["name"] + + @property + def data_center_ids(self): + return [dc["id"] for dc in self.o["data_centers"]] + + @property + def data_centers_names(self): + return [dc["name"] for dc in self.o["data_centers"]] + + def get_node_ids(self, datacenter_id): + datacenter = list(filter(lambda x: x["id"] == datacenter_id, self.o["data_centers"])).pop() + return [node["id"] for node in datacenter["nodes"]] + + +class SimulacronServer(object): + """ + Class for starting and stopping the server from within the tests + """ + def __init__(self, jar_path): + self.jar_path = jar_path + self.running = False + self.proc = None + + def start(self): + self.proc = subprocess.Popen(['java', '-jar', self.jar_path, "--loglevel", "ERROR"], shell=False) + self.running = True + + def stop(self): + if self.proc: + self.proc.terminate() + self.running = False + + def is_running(self): + # We could check self.proc.poll here instead + return self.running + + +SERVER_SIMULACRON = SimulacronServer(SIMULACRON_JAR) + + +def start_simulacron(): + """ + Starts and waits for simulacron to run + """ + if SERVER_SIMULACRON.is_running(): + SERVER_SIMULACRON.stop() + + SERVER_SIMULACRON.start() + + # TODO improve this sleep, maybe check the logs like ccm + time.sleep(5) + + +def stop_simulacron(): + SERVER_SIMULACRON.stop() + + +class SimulacronClient(object): + def __init__(self, admin_addr="127.0.0.1:8187"): + self.admin_addr = admin_addr + + def submit_request(self, query): + opener = build_opener(HTTPHandler) + data = json.dumps(query.fetch_json()).encode('utf8') + + request = Request("http://{}/{}{}".format( + self.admin_addr, query.path, query.fetch_url_params()), data=data) + request.get_method = lambda: query.method + request.add_header("Content-Type", 'application/json') + request.add_header("Content-Length", len(data)) + + # wait that simulacron is ready and listening + connection = wait_until_not_raised(lambda: opener.open(request), 1, 10) + return connection.read().decode('utf-8') + + def prime_server_versions(self): + """ + This information has to be primed for the test harness to run + """ + system_local_row = {} + system_local_row["cql_version"] = CASSANDRA_VERSION.base_version + system_local_row["release_version"] = CASSANDRA_VERSION.base_version + "-SNAPSHOT" + if DSE_VERSION: + system_local_row["dse_version"] = DSE_VERSION.base_version + column_types = {"cql_version": "ascii", "release_version": "ascii"} + system_local = PrimeQuery("SELECT cql_version, release_version FROM system.local", + rows=[system_local_row], + column_types=column_types) + + self.submit_request(system_local) + + def clear_all_queries(self, cluster_name=DEFAULT_CLUSTER): + """ + Clear all the primed queries from a particular cluster + :param cluster_name: cluster to clear queries from + """ + 
opener = build_opener(HTTPHandler) + request = Request("http://{0}/{1}/{2}".format( + self.admin_addr, "prime", cluster_name)) + request.get_method = lambda: 'DELETE' + connection = opener.open(request) + return connection.read() + + +NO_THEN = object() + + +class SimulacronRequest(object): + def fetch_json(self): + return {} + + def fetch_url_params(self): + return "" + + @property + def method(self): + raise NotImplementedError() + + +class PrimeOptions(SimulacronRequest): + """ + Class used for specifying how should simulacron respond to an OptionsMessage + """ + def __init__(self, then=None, cluster_name=DEFAULT_CLUSTER): + self.path = "prime/{}".format(cluster_name) + self.then = then + + def fetch_json(self): + json_dict = {} + then = {} + when = {} + + when['request'] = "options" + + if self.then is not None and self.then is not NO_THEN: + then.update(self.then) + + json_dict['when'] = when + if self.then is not NO_THEN: + json_dict['then'] = then + + return json_dict + + def fetch_url_params(self): + return "" + + @property + def method(self): + return "POST" + + +class RejectType(): + UNBIND = "UNBIND" + STOP = "STOP" + REJECT_STARTUP = "REJECT_STARTUP" + + +class RejectConnections(SimulacronRequest): + """ + Class used for making simulacron reject new connections + """ + def __init__(self, reject_type, cluster_name=DEFAULT_CLUSTER): + self.path = "listener/{}".format(cluster_name) + self.reject_type = reject_type + + def fetch_url_params(self): + return "?type={0}".format(self.reject_type) + + @property + def method(self): + return "DELETE" + + +class AcceptConnections(SimulacronRequest): + """ + Class used for making simulacron reject new connections + """ + def __init__(self, cluster_name=DEFAULT_CLUSTER): + self.path = "listener/{}".format(cluster_name) + + @property + def method(self): + return "PUT" + + +class PrimeQuery(SimulacronRequest): + """ + Class used for specifying how should simulacron respond to particular query + """ + def __init__(self, expected_query, result="success", rows=None, + column_types=None, when=None, then=None, cluster_name=DEFAULT_CLUSTER): + self.expected_query = expected_query + self.rows = rows + self.result = result + self.column_types = column_types + self.path = "prime/{}".format(cluster_name) + self.then = then + self.when = when + + def fetch_json(self): + json_dict = {} + then = {} + when = {} + + when['query'] = self.expected_query + then['result'] = self.result + if self.rows is not None: + then['rows'] = self.rows + + if self.column_types is not None: + then['column_types'] = self.column_types + + if self.then is not None and self.then is not NO_THEN: + then.update(self.then) + + if self.then is not NO_THEN: + json_dict['then'] = then + + if self.when is not None: + when.update(self.when) + + json_dict['when'] = when + + return json_dict + + def set_node(self, cluster_id, datacenter_id, node_id): + self.cluster_id = cluster_id + self.datacenter_id = datacenter_id + self.node_id = node_id + + self.path += '/'.join([component for component in + (self.cluster_id, self.datacenter_id, self.node_id) + if component is not None]) + + def fetch_url_params(self): + return "" + + @property + def method(self): + return "POST" + + +class ClusterQuery(SimulacronRequest): + """ + Class used for creating a cluster + """ + def __init__(self, cluster_name, cassandra_version, data_centers="3", json_dict=None, dse_version=None): + self.cluster_name = cluster_name + self.cassandra_version = cassandra_version + self.dse_version = dse_version + 
self.data_centers = data_centers + if json_dict is None: + self.json_dict = {} + else: + self.json_dict = json_dict + + self.path = "cluster" + + def fetch_json(self): + return self.json_dict + + def fetch_url_params(self): + q = "?cassandra_version={0}&data_centers={1}&name={2}".\ + format(self.cassandra_version, self.data_centers, self.cluster_name) + if self.dse_version: + q += "&dse_version={0}".format(self.dse_version) + + return q + + @property + def method(self): + return "POST" + + +class GetLogsQuery(SimulacronRequest): + """ + Class used to get logs from simulacron + """ + def __init__(self, cluster_name=DEFAULT_CLUSTER, dc_id=0): + self.path = "log/{}/{}".format(cluster_name, dc_id) + + @property + def method(self): + return "GET" + + +class ClearLogsQuery(SimulacronRequest): + """ + Class used to get logs from simulacron + """ + def __init__(self, cluster_name=DEFAULT_CLUSTER, dc_id=0): + self.path = "log/{}/{}".format(cluster_name, dc_id) + + @property + def method(self): + return "DELETE" + + +class _PauseOrResumeReads(SimulacronRequest): + def __init__(self, cluster_name=DEFAULT_CLUSTER, dc_id=None, node_id=None): + self.path = "pause-reads/{}".format(cluster_name) + if dc_id is not None: + self.path += "/{}".format(dc_id) + if node_id is not None: + self.path += "/{}".format(node_id) + elif node_id: + raise Exception("Can't set node_id without dc_id") + + @property + def method(self): + raise NotImplementedError() + + +class PauseReads(_PauseOrResumeReads): + @property + def method(self): + return "PUT" + + +class ResumeReads(_PauseOrResumeReads): + @property + def method(self): + return "DELETE" + + +def prime_driver_defaults(): + """ + Function to prime the necessary queries so the test harness can run + """ + client_simulacron = SimulacronClient() + client_simulacron.prime_server_versions() + + # prepare InvalidResponses for virtual tables + for query in [SchemaParserV4._SELECT_VIRTUAL_KEYSPACES, + SchemaParserV4._SELECT_VIRTUAL_TABLES, + SchemaParserV4._SELECT_VIRTUAL_COLUMNS]: + client_simulacron.submit_request( + PrimeQuery(query, result='invalid', + then={"result": "invalid", + "delay_in_ms": 0, + "ignore_on_prepare": True, + "message": "Invalid Query!"}) + ) + + # prepare empty rows for NGDG + for query in [SchemaParserDSE68._SELECT_VERTICES, + SchemaParserDSE68._SELECT_EDGES]: + client_simulacron.submit_request( + PrimeQuery(query, result='success', + then={'rows': [], 'column_types': {'row1': 'int'}})) + + +def prime_cluster(data_centers="3", version=None, cluster_name=DEFAULT_CLUSTER, dse_version=None): + """ + Creates a new cluster in the simulacron server + :param cluster_name: name of the cluster + :param data_centers: string describing the datacenter, e.g. 
2/3 would be two + datacenters of 2 nodes and three nodes + :param version: C* version + """ + version = version or CASSANDRA_VERSION + cluster_query = ClusterQuery(cluster_name, version, data_centers, dse_version=dse_version) + client_simulacron = SimulacronClient() + response = client_simulacron.submit_request(cluster_query) + return SimulacronCluster(response) + + +def start_and_prime_singledc(cluster_name=DEFAULT_CLUSTER): + """ + Starts simulacron and creates a cluster with a single datacenter + :param cluster_name: name of the cluster to start and prime + :return: + """ + return start_and_prime_cluster_defaults(number_of_dc=1, nodes_per_dc=3, cluster_name=cluster_name) + + +def start_and_prime_cluster_defaults(number_of_dc=1, nodes_per_dc=3, version=CASSANDRA_VERSION, + cluster_name=DEFAULT_CLUSTER, dse_version=None): + """ + :param number_of_dc: number of datacentes + :param nodes_per_dc: number of nodes per datacenter + :param version: C* version + """ + start_simulacron() + data_centers = ",".join([str(nodes_per_dc)] * number_of_dc) + simulacron_cluster = prime_cluster(data_centers=data_centers, version=version, + cluster_name=cluster_name, dse_version=dse_version) + prime_driver_defaults() + + return simulacron_cluster + + +default_column_types = { + "key": "bigint", + "value": "ascii" +} + +default_row = {"key": 2, "value": "value"} +default_rows = [default_row] + + +def prime_request(request): + """ + :param request: It could be PrimeQuery class or an PrimeOptions class + """ + return SimulacronClient().submit_request(request) + + +def prime_query(query, rows=default_rows, column_types=default_column_types, when=None, then=None, cluster_name=DEFAULT_CLUSTER): + """ + Shortcut function for priming a query + :return: + """ + # If then is set, then rows and column_types should not + query = PrimeQuery(query, rows=rows, column_types=column_types, when=when, then=then, cluster_name=cluster_name) + response = prime_request(query) + return response + + +def clear_queries(): + """ + Clears all the queries that have been primed to simulacron + """ + SimulacronClient().clear_all_queries() diff --git a/tests/integration/standard/__init__.py b/tests/integration/standard/__init__.py index ba5c6e26fd..13d6eb6071 100644 --- a/tests/integration/standard/__init__.py +++ b/tests/integration/standard/__init__.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,10 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest try: from ccmlib import common diff --git a/tests/integration/standard/column_encryption/test_policies.py b/tests/integration/standard/column_encryption/test_policies.py new file mode 100644 index 0000000000..0d692ac5c1 --- /dev/null +++ b/tests/integration/standard/column_encryption/test_policies.py @@ -0,0 +1,174 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest + +from tests.integration import use_singledc, TestCluster + +from cassandra.policies import ColDesc + +from cassandra.column_encryption.policies import AES256ColumnEncryptionPolicy, \ + AES256_KEY_SIZE_BYTES, AES256_BLOCK_SIZE_BYTES + + +def setup_module(): + use_singledc() + + +class ColumnEncryptionPolicyTest(unittest.TestCase): + + def _recreate_keyspace(self, session): + session.execute("drop keyspace if exists foo") + session.execute("CREATE KEYSPACE foo WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}") + session.execute("CREATE TABLE foo.bar(encrypted blob, unencrypted int, primary key(unencrypted))") + + def _create_policy(self, key, iv = None): + cl_policy = AES256ColumnEncryptionPolicy() + col_desc = ColDesc('foo','bar','encrypted') + cl_policy.add_column(col_desc, key, "int") + return (col_desc, cl_policy) + + def test_end_to_end_prepared(self): + + # We only currently perform testing on a single type/expected value pair since CLE functionality is essentially + # independent of the underlying type. We intercept data after it's been encoded when it's going out and before it's + # encoded when coming back; the actual types of the data involved don't impact us. + expected = 0 + + key = os.urandom(AES256_KEY_SIZE_BYTES) + (_, cl_policy) = self._create_policy(key) + cluster = TestCluster(column_encryption_policy=cl_policy) + session = cluster.connect() + self._recreate_keyspace(session) + + prepared = session.prepare("insert into foo.bar (encrypted, unencrypted) values (?,?)") + for i in range(100): + session.execute(prepared, (i, i)) + + # A straight select from the database will now return the decrypted bits. We select both encrypted and unencrypted + # values here to confirm that we don't interfere with regular processing of unencrypted vals. + (encrypted, unencrypted) = session.execute("select encrypted, unencrypted from foo.bar where unencrypted = %s allow filtering", (expected,)).one() + self.assertEqual(expected, encrypted) + self.assertEqual(expected, unencrypted) + + # Confirm the same behaviour from a subsequent prepared statement as well + prepared = session.prepare("select encrypted, unencrypted from foo.bar where unencrypted = ? 
allow filtering") + (encrypted, unencrypted) = session.execute(prepared, [expected]).one() + self.assertEqual(expected, encrypted) + self.assertEqual(expected, unencrypted) + + def test_end_to_end_simple(self): + + expected = 1 + + key = os.urandom(AES256_KEY_SIZE_BYTES) + (col_desc, cl_policy) = self._create_policy(key) + cluster = TestCluster(column_encryption_policy=cl_policy) + session = cluster.connect() + self._recreate_keyspace(session) + + # Use encode_and_encrypt helper function to populate date + for i in range(1, 100): + self.assertIsNotNone(i) + encrypted = cl_policy.encode_and_encrypt(col_desc, i) + session.execute("insert into foo.bar (encrypted, unencrypted) values (%s,%s)", (encrypted, i)) + + # A straight select from the database will now return the decrypted bits. We select both encrypted and unencrypted + # values here to confirm that we don't interfere with regular processing of unencrypted vals. + (encrypted, unencrypted) = session.execute("select encrypted, unencrypted from foo.bar where unencrypted = %s allow filtering", (expected,)).one() + self.assertEqual(expected, encrypted) + self.assertEqual(expected, unencrypted) + + # Confirm the same behaviour from a subsequent prepared statement as well + prepared = session.prepare("select encrypted, unencrypted from foo.bar where unencrypted = ? allow filtering") + (encrypted, unencrypted) = session.execute(prepared, [expected]).one() + self.assertEqual(expected, encrypted) + self.assertEqual(expected, unencrypted) + + def test_end_to_end_different_cle_contexts_different_ivs(self): + """ + Test to validate PYTHON-1350. We should be able to decode the data from two different contexts (with two different IVs) + since the IV used to decrypt the data is actually now stored with the data. + """ + + expected = 2 + + key = os.urandom(AES256_KEY_SIZE_BYTES) + + # Simulate the creation of two AES256 policies at two different times. Python caches + # default param args at function definition time so a single value will be used any time + # the default val is used. Upshot is that within the same test we'll always have the same + # IV if we rely on the default args, so manually introduce some variation here to simulate + # what actually happens if you have two distinct sessions created at two different times. 
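+        # Each context below generates its own IV; per PYTHON-1350 the IV used for
+        # encryption is stored alongside the data, so the second context can still decrypt it.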
+ iv1 = os.urandom(AES256_BLOCK_SIZE_BYTES) + (col_desc1, cl_policy1) = self._create_policy(key, iv=iv1) + cluster1 = TestCluster(column_encryption_policy=cl_policy1) + session1 = cluster1.connect() + self._recreate_keyspace(session1) + + # Use encode_and_encrypt helper function to populate date + for i in range(1, 100): + self.assertIsNotNone(i) + encrypted = cl_policy1.encode_and_encrypt(col_desc1, i) + session1.execute("insert into foo.bar (encrypted, unencrypted) values (%s,%s)", (encrypted, i)) + session1.shutdown() + cluster1.shutdown() + + # Explicitly clear the class-level cache here; we're trying to simulate a second connection from a completely new process and + # that would entail not re-using any cached ciphers + AES256ColumnEncryptionPolicy._build_cipher.cache_clear() + cache_info = cl_policy1.cache_info() + self.assertEqual(cache_info.currsize, 0) + + iv2 = os.urandom(AES256_BLOCK_SIZE_BYTES) + (_, cl_policy2) = self._create_policy(key, iv=iv2) + cluster2 = TestCluster(column_encryption_policy=cl_policy2) + session2 = cluster2.connect() + (encrypted, unencrypted) = session2.execute("select encrypted, unencrypted from foo.bar where unencrypted = %s allow filtering", (expected,)).one() + self.assertEqual(expected, encrypted) + self.assertEqual(expected, unencrypted) + + def test_end_to_end_different_cle_contexts_different_policies(self): + """ + Test to validate PYTHON-1356. Class variables used to pass CLE policy down to protocol handler shouldn't persist. + """ + + expected = 3 + + key = os.urandom(AES256_KEY_SIZE_BYTES) + (col_desc, cl_policy) = self._create_policy(key) + cluster = TestCluster(column_encryption_policy=cl_policy) + session = cluster.connect() + self._recreate_keyspace(session) + + # Use encode_and_encrypt helper function to populate date + session.execute("insert into foo.bar (encrypted, unencrypted) values (%s,%s)", (cl_policy.encode_and_encrypt(col_desc, expected), expected)) + + # We now open a new session _without_ the CLE policy specified. We should _not_ be able to read decrypted bits from this session. + cluster2 = TestCluster() + session2 = cluster2.connect() + + # A straight select from the database will now return the decrypted bits. We select both encrypted and unencrypted + # values here to confirm that we don't interfere with regular processing of unencrypted vals. + (encrypted, unencrypted) = session2.execute("select encrypted, unencrypted from foo.bar where unencrypted = %s allow filtering", (expected,)).one() + self.assertEqual(cl_policy.encode_and_encrypt(col_desc, expected), encrypted) + self.assertEqual(expected, unencrypted) + + # Confirm the same behaviour from a subsequent prepared statement as well + prepared = session2.prepare("select encrypted, unencrypted from foo.bar where unencrypted = ? allow filtering") + (encrypted, unencrypted) = session2.execute(prepared, [expected]).one() + self.assertEqual(cl_policy.encode_and_encrypt(col_desc, expected), encrypted) diff --git a/tests/integration/standard/conftest.py b/tests/integration/standard/conftest.py new file mode 100644 index 0000000000..6028c2a06d --- /dev/null +++ b/tests/integration/standard/conftest.py @@ -0,0 +1,13 @@ +import pytest +import logging + +# from https://github.com/streamlit/streamlit/pull/5047/files +def pytest_sessionfinish(): + # We're not waiting for scriptrunner threads to cleanly close before ending the PyTest, + # which results in raised exception ValueError: I/O operation on closed file. 
+ # This is well known issue in PyTest, check out these discussions for more: + # * https://github.com/pytest-dev/pytest/issues/5502 + # * https://github.com/pytest-dev/pytest/issues/5282 + # To prevent the exception from being raised on pytest_sessionfinish + # we disable exception raising in logging module + logging.raiseExceptions = False \ No newline at end of file diff --git a/tests/integration/standard/test_authentication.py b/tests/integration/standard/test_authentication.py index 473b398a12..2d47a93529 100644 --- a/tests/integration/standard/test_authentication.py +++ b/tests/integration/standard/test_authentication.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,46 +14,53 @@ # See the License for the specific language governing permissions and # limitations under the License. +from packaging.version import Version import logging import time -from cassandra.cluster import Cluster, NoHostAvailable +from cassandra.cluster import NoHostAvailable from cassandra.auth import PlainTextAuthProvider, SASLClient, SaslAuthProvider -from tests.integration import use_singledc, get_cluster, remove_cluster, PROTOCOL_VERSION +from tests.integration import use_singledc, get_cluster, remove_cluster, PROTOCOL_VERSION, \ + CASSANDRA_IP, CASSANDRA_VERSION, USE_CASS_EXTERNAL, start_cluster_wait_for_up, TestCluster from tests.integration.util import assert_quiescent_pool_state -try: - import unittest2 as unittest -except ImportError: - import unittest +import unittest log = logging.getLogger(__name__) -def setup_module(): - use_singledc(start=False) - ccm_cluster = get_cluster() - ccm_cluster.stop() - config_options = {'authenticator': 'PasswordAuthenticator', - 'authorizer': 'CassandraAuthorizer'} - ccm_cluster.set_configuration_options(config_options) - log.debug("Starting ccm test cluster with %s", config_options) - ccm_cluster.start(wait_for_binary_proto=True, wait_other_notice=True) - # there seems to be some race, with some versions of C* taking longer to - # get the auth (and default user) setup. 
Sleep here to give it a chance - time.sleep(10) +#This can be tested for remote hosts, but the cluster has to be configured accordingly +#@local +def setup_module(): + if CASSANDRA_IP.startswith("127.0.0.") and not USE_CASS_EXTERNAL: + use_singledc(start=False) + ccm_cluster = get_cluster() + ccm_cluster.stop() + config_options = {'authenticator': 'PasswordAuthenticator', + 'authorizer': 'CassandraAuthorizer'} + ccm_cluster.set_configuration_options(config_options) + log.debug("Starting ccm test cluster with %s", config_options) + start_cluster_wait_for_up(ccm_cluster) + + # PYTHON-1328 + # + # Give the cluster enough time to startup (and perform necessary initialization) + # before executing the test. + if CASSANDRA_VERSION > Version('4.0-a'): + time.sleep(10) + def teardown_module(): remove_cluster() # this test messes with config class AuthenticationTests(unittest.TestCase): + """ Tests to cover basic authentication functionality """ - def get_authentication_provider(self, username, password): """ Return correct authentication provider based on protocol version. @@ -69,11 +78,25 @@ def get_authentication_provider(self, username, password): return PlainTextAuthProvider(username=username, password=password) def cluster_as(self, usr, pwd): - return Cluster(protocol_version=PROTOCOL_VERSION, - idle_heartbeat_interval=0, - auth_provider=self.get_authentication_provider(username=usr, password=pwd)) + # test we can connect at least once with creds + # to ensure the role manager is setup + for _ in range(5): + try: + cluster = TestCluster( + idle_heartbeat_interval=0, + auth_provider=self.get_authentication_provider(username='cassandra', password='cassandra')) + cluster.connect(wait_for_all_pools=True) + + return TestCluster( + idle_heartbeat_interval=0, + auth_provider=self.get_authentication_provider(username=usr, password=pwd)) + except Exception as e: + time.sleep(5) + + raise Exception('Unable to connect with creds: {}/{}'.format(usr, pwd)) def test_auth_connect(self): + user = 'u' passwd = 'password' @@ -82,10 +105,10 @@ def test_auth_connect(self): try: cluster = self.cluster_as(user, passwd) - session = cluster.connect() + session = cluster.connect(wait_for_all_pools=True) try: self.assertTrue(session.execute('SELECT release_version FROM system.local')) - assert_quiescent_pool_state(self, cluster) + assert_quiescent_pool_state(self, cluster, wait=1) for pool in session.get_pools(): connection, _ = pool.borrow_connection(timeout=0) self.assertEqual(connection.authenticator.server_authenticator_class, 'org.apache.cassandra.auth.PasswordAuthenticator') @@ -94,50 +117,54 @@ def test_auth_connect(self): cluster.shutdown() finally: root_session.execute('DROP USER %s', user) - assert_quiescent_pool_state(self, root_session.cluster) + assert_quiescent_pool_state(self, root_session.cluster, wait=1) root_session.cluster.shutdown() def test_connect_wrong_pwd(self): cluster = self.cluster_as('cassandra', 'wrong_pass') - self.assertRaisesRegexp(NoHostAvailable, - '.*AuthenticationFailed.*Bad credentials.*Username and/or ' - 'password are incorrect.*', - cluster.connect) - assert_quiescent_pool_state(self, cluster) - cluster.shutdown() + try: + self.assertRaisesRegex(NoHostAvailable, + '.*AuthenticationFailed.', + cluster.connect) + assert_quiescent_pool_state(self, cluster) + finally: + cluster.shutdown() def test_connect_wrong_username(self): cluster = self.cluster_as('wrong_user', 'cassandra') - self.assertRaisesRegexp(NoHostAvailable, - '.*AuthenticationFailed.*Bad credentials.*Username and/or ' 
- 'password are incorrect.*', - cluster.connect) - assert_quiescent_pool_state(self, cluster) - cluster.shutdown() + try: + self.assertRaisesRegex(NoHostAvailable, + '.*AuthenticationFailed.*', + cluster.connect) + assert_quiescent_pool_state(self, cluster) + finally: + cluster.shutdown() def test_connect_empty_pwd(self): cluster = self.cluster_as('Cassandra', '') - self.assertRaisesRegexp(NoHostAvailable, - '.*AuthenticationFailed.*Bad credentials.*Username and/or ' - 'password are incorrect.*', - cluster.connect) - assert_quiescent_pool_state(self, cluster) - cluster.shutdown() + try: + self.assertRaisesRegex(NoHostAvailable, + '.*AuthenticationFailed.*', + cluster.connect) + assert_quiescent_pool_state(self, cluster) + finally: + cluster.shutdown() def test_connect_no_auth_provider(self): - cluster = Cluster(protocol_version=PROTOCOL_VERSION) - self.assertRaisesRegexp(NoHostAvailable, - '.*AuthenticationFailed.*Remote end requires authentication.*', - cluster.connect) - assert_quiescent_pool_state(self, cluster) - cluster.shutdown() + cluster = TestCluster() + try: + self.assertRaisesRegex(NoHostAvailable, + '.*AuthenticationFailed.*', + cluster.connect) + assert_quiescent_pool_state(self, cluster) + finally: + cluster.shutdown() class SaslAuthenticatorTests(AuthenticationTests): """ Test SaslAuthProvider as PlainText """ - def setUp(self): if PROTOCOL_VERSION < 2: raise unittest.SkipTest('Sasl authentication not available for protocol v1') diff --git a/tests/integration/standard/test_authentication_misconfiguration.py b/tests/integration/standard/test_authentication_misconfiguration.py new file mode 100644 index 0000000000..a2e2c019a5 --- /dev/null +++ b/tests/integration/standard/test_authentication_misconfiguration.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from tests.integration import USE_CASS_EXTERNAL, use_cluster, TestCluster + + +class MisconfiguredAuthenticationTests(unittest.TestCase): + """ One node (not the contact point) has password auth. 
The rest of the nodes have no auth """ + @classmethod + def setUpClass(cls): + if not USE_CASS_EXTERNAL: + ccm_cluster = use_cluster(cls.__name__, [3], start=False) + node3 = ccm_cluster.nodes['node3'] + node3.set_configuration_options(values={ + 'authenticator': 'PasswordAuthenticator', + 'authorizer': 'CassandraAuthorizer', + }) + ccm_cluster.start(wait_for_binary_proto=True) + + cls.ccm_cluster = ccm_cluster + + def test_connect_no_auth_provider(self): + cluster = TestCluster() + cluster.connect() + cluster.refresh_nodes() + down_hosts = [host for host in cluster.metadata.all_hosts() if not host.is_up] + self.assertEqual(len(down_hosts), 1) + cluster.shutdown() + + @classmethod + def tearDownClass(cls): + if not USE_CASS_EXTERNAL: + cls.ccm_cluster.stop() diff --git a/tests/integration/standard/test_client_warnings.py b/tests/integration/standard/test_client_warnings.py index a463578f4e..d20251772a 100644 --- a/tests/integration/standard/test_client_warnings.py +++ b/tests/integration/standard/test_client_warnings.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +15,11 @@ # limitations under the License. 
-try: - import unittest2 as unittest -except ImportError: - import unittest +import unittest from cassandra.query import BatchStatement -from cassandra.cluster import Cluster -from tests.integration import use_singledc, PROTOCOL_VERSION +from tests.integration import use_singledc, PROTOCOL_VERSION, local, TestCluster def setup_module(): @@ -35,7 +33,7 @@ def setUpClass(cls): if PROTOCOL_VERSION < 4: return - cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cls.cluster = TestCluster() cls.session = cls.cluster.connect() cls.session.execute("CREATE TABLE IF NOT EXISTS test1rf.client_warning (k int, v0 int, v1 int, PRIMARY KEY (k, v0))") @@ -74,7 +72,7 @@ def test_warning_basic(self): future = self.session.execute_async(self.warn_batch) future.result() self.assertEqual(len(future.warnings), 1) - self.assertRegexpMatches(future.warnings[0], 'Batch.*exceeding.*') + self.assertRegex(future.warnings[0], 'Batch.*exceeding.*') def test_warning_with_trace(self): """ @@ -90,9 +88,10 @@ def test_warning_with_trace(self): future = self.session.execute_async(self.warn_batch, trace=True) future.result() self.assertEqual(len(future.warnings), 1) - self.assertRegexpMatches(future.warnings[0], 'Batch.*exceeding.*') + self.assertRegex(future.warnings[0], 'Batch.*exceeding.*') self.assertIsNotNone(future.get_query_trace()) + @local def test_warning_with_custom_payload(self): """ Test to validate client warning with custom payload @@ -108,9 +107,10 @@ def test_warning_with_custom_payload(self): future = self.session.execute_async(self.warn_batch, custom_payload=payload) future.result() self.assertEqual(len(future.warnings), 1) - self.assertRegexpMatches(future.warnings[0], 'Batch.*exceeding.*') + self.assertRegex(future.warnings[0], 'Batch.*exceeding.*') self.assertDictEqual(future.custom_payload, payload) + @local def test_warning_with_trace_and_custom_payload(self): """ Test to validate client warning with tracing and client warning @@ -126,6 +126,6 @@ def test_warning_with_trace_and_custom_payload(self): future = self.session.execute_async(self.warn_batch, trace=True, custom_payload=payload) future.result() self.assertEqual(len(future.warnings), 1) - self.assertRegexpMatches(future.warnings[0], 'Batch.*exceeding.*') + self.assertRegex(future.warnings[0], 'Batch.*exceeding.*') self.assertIsNotNone(future.get_query_trace()) self.assertDictEqual(future.custom_payload, payload) diff --git a/tests/integration/standard/test_cluster.py b/tests/integration/standard/test_cluster.py index 003c78a913..c6fc2a717f 100644 --- a/tests/integration/standard/test_cluster.py +++ b/tests/integration/standard/test_cluster.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,35 +14,121 @@ # See the License for the specific language governing permissions and # limitations under the License. -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest from collections import deque -from mock import patch +from copy import copy +from unittest.mock import Mock, call, patch, ANY import time from uuid import uuid4 +import logging +import warnings +from packaging.version import Version import cassandra -from cassandra.cluster import Cluster, NoHostAvailable +from cassandra.cluster import NoHostAvailable, ExecutionProfile, EXEC_PROFILE_DEFAULT, ControlConnection, Cluster from cassandra.concurrent import execute_concurrent from cassandra.policies import (RoundRobinPolicy, ExponentialReconnectionPolicy, RetryPolicy, SimpleConvictionPolicy, HostDistance, - WhiteListRoundRobinPolicy) -from cassandra.protocol import MAX_SUPPORTED_VERSION -from cassandra.query import SimpleStatement, TraceUnavailable - -from tests.integration import use_singledc, PROTOCOL_VERSION, get_server_versions, get_node, CASSANDRA_VERSION, execute_until_pass, execute_with_long_wait_retry + AddressTranslator, TokenAwarePolicy, HostFilterPolicy) +from cassandra import ConsistencyLevel + +from cassandra.query import SimpleStatement, TraceUnavailable, tuple_factory +from cassandra.auth import PlainTextAuthProvider, SaslAuthProvider +from cassandra import connection +from cassandra.connection import DefaultEndPoint + +from tests import notwindows +from tests.integration import use_singledc, get_server_versions, CASSANDRA_VERSION, \ + execute_until_pass, execute_with_long_wait_retry, get_node, MockLoggingHandler, get_unsupported_lower_protocol, \ + get_unsupported_upper_protocol, protocolv6, local, CASSANDRA_IP, greaterthanorequalcass30, lessthanorequalcass40, \ + DSE_VERSION, TestCluster, PROTOCOL_VERSION from tests.integration.util import assert_quiescent_pool_state +import sys + +log = logging.getLogger(__name__) def setup_module(): use_singledc() + warnings.simplefilter("always") + + +class IgnoredHostPolicy(RoundRobinPolicy): + + def __init__(self, ignored_hosts): + self.ignored_hosts = ignored_hosts + RoundRobinPolicy.__init__(self) + + def distance(self, host): + if(host.address in self.ignored_hosts): + return HostDistance.IGNORED + else: + return HostDistance.LOCAL class ClusterTests(unittest.TestCase): + @local + def test_ignored_host_up(self): + """ + Test to ensure that is_up is not set by default on ignored hosts + + @since 3.6 + @jira_ticket PYTHON-551 + @expected_result ignored hosts should have None set for is_up + + @test_category connection + """ + ignored_host_policy = IgnoredHostPolicy(["127.0.0.2", "127.0.0.3"]) + cluster = TestCluster( + execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(load_balancing_policy=ignored_host_policy)} + ) + cluster.connect() + for host in cluster.metadata.all_hosts(): + if str(host) == "127.0.0.1:9042": + self.assertTrue(host.is_up) + else: + self.assertIsNone(host.is_up) + cluster.shutdown() + + @local + def test_host_resolution(self): + """ + Test to insure A records are resolved appropriately. 
+ + @since 3.3 + @jira_ticket PYTHON-415 + @expected_result hostname will be transformed into IP + + @test_category connection + """ + cluster = TestCluster(contact_points=["localhost"], connect_timeout=1) + self.assertTrue(DefaultEndPoint('127.0.0.1') in cluster.endpoints_resolved) + + @local + def test_host_duplication(self): + """ + Ensure that duplicate hosts in the contact points are not surfaced in the cluster metadata + + @since 3.3 + @jira_ticket PYTHON-103 + @expected_result duplicate hosts aren't surfaced in cluster.metadata + + @test_category connection + """ + cluster = TestCluster( + contact_points=["localhost", "127.0.0.1", "localhost", "localhost", "localhost"], + connect_timeout=1 + ) + cluster.connect(wait_for_all_pools=True) + self.assertEqual(len(cluster.metadata.all_hosts()), 3) + cluster.shutdown() + cluster = TestCluster(contact_points=["127.0.0.1", "localhost"], connect_timeout=1) + cluster.connect(wait_for_all_pools=True) + self.assertEqual(len(cluster.metadata.all_hosts()), 3) + cluster.shutdown() + @local def test_raise_error_on_control_connection_timeout(self): """ Test for initial control connection timeout @@ -59,10 +147,11 @@ def test_raise_error_on_control_connection_timeout(self): """ get_node(1).pause() - cluster = Cluster(contact_points=['127.0.0.1'], protocol_version=PROTOCOL_VERSION, connect_timeout=1) + cluster = TestCluster(contact_points=['127.0.0.1'], connect_timeout=1) - with self.assertRaisesRegexp(NoHostAvailable, "OperationTimedOut\('errors=Timed out creating connection \(1 seconds\)"): + with self.assertRaisesRegex(NoHostAvailable, "OperationTimedOut\('errors=Timed out creating connection \(1 seconds\)"): cluster.connect() + cluster.shutdown() get_node(1).resume() @@ -71,7 +160,7 @@ def test_basic(self): Test basic connection and usage """ - cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cluster = TestCluster() session = cluster.connect() result = execute_until_pass(session, """ @@ -104,6 +193,45 @@ def test_basic(self): cluster.shutdown() + def test_session_host_parameter(self): + """ + Test Session creation with valid and invalid contact points + + Verify that NoHostAvailable is raised in Session.__init__ when there are no valid connections and that + no error is raised otherwise, even if some of the contact points are invalid + + @since 3.9 + @jira_ticket PYTHON-665 + @expected_result NoHostAvailable when the driver is unable to connect to a valid host, + no exception otherwise + + @test_category connection + """ + def cleanup(): + """ + When this test fails, the inline .shutdown() calls don't get + called, so we register this as a cleanup.
+ """ + self.cluster_to_shutdown.shutdown() + self.addCleanup(cleanup) + + # Test with empty list + self.cluster_to_shutdown = TestCluster(contact_points=[]) + with self.assertRaises(NoHostAvailable): + self.cluster_to_shutdown.connect() + self.cluster_to_shutdown.shutdown() + + # Test with only invalid + self.cluster_to_shutdown = TestCluster(contact_points=('1.2.3.4',)) + with self.assertRaises(NoHostAvailable): + self.cluster_to_shutdown.connect() + self.cluster_to_shutdown.shutdown() + + # Test with valid and invalid hosts + self.cluster_to_shutdown = TestCluster(contact_points=("127.0.0.1", "127.0.0.2", "1.2.3.4")) + self.cluster_to_shutdown.connect() + self.cluster_to_shutdown.shutdown() + def test_protocol_negotiation(self): """ Test for protocol negotiation @@ -121,18 +249,36 @@ def test_protocol_negotiation(self): """ cluster = Cluster() - self.assertEqual(cluster.protocol_version, MAX_SUPPORTED_VERSION) + self.assertLessEqual(cluster.protocol_version, cassandra.ProtocolVersion.MAX_SUPPORTED) session = cluster.connect() updated_protocol_version = session._protocol_version updated_cluster_version = cluster.protocol_version # Make sure the correct protocol was selected by default - if CASSANDRA_VERSION >= '2.2': + if DSE_VERSION and DSE_VERSION >= Version("6.0"): + self.assertEqual(updated_protocol_version, cassandra.ProtocolVersion.DSE_V2) + self.assertEqual(updated_cluster_version, cassandra.ProtocolVersion.DSE_V2) + elif DSE_VERSION and DSE_VERSION >= Version("5.1"): + self.assertEqual(updated_protocol_version, cassandra.ProtocolVersion.DSE_V1) + self.assertEqual(updated_cluster_version, cassandra.ProtocolVersion.DSE_V1) + elif CASSANDRA_VERSION >= Version('4.0-beta5'): + self.assertEqual(updated_protocol_version, cassandra.ProtocolVersion.V5) + self.assertEqual(updated_cluster_version, cassandra.ProtocolVersion.V5) + elif CASSANDRA_VERSION >= Version('4.0-a'): + self.assertEqual(updated_protocol_version, cassandra.ProtocolVersion.V4) + self.assertEqual(updated_cluster_version, cassandra.ProtocolVersion.V4) + elif CASSANDRA_VERSION >= Version('3.11'): + self.assertEqual(updated_protocol_version, cassandra.ProtocolVersion.V4) + self.assertEqual(updated_cluster_version, cassandra.ProtocolVersion.V4) + elif CASSANDRA_VERSION >= Version('3.0'): + self.assertEqual(updated_protocol_version, cassandra.ProtocolVersion.V4) + self.assertEqual(updated_cluster_version, cassandra.ProtocolVersion.V4) + elif CASSANDRA_VERSION >= Version('2.2'): self.assertEqual(updated_protocol_version, 4) self.assertEqual(updated_cluster_version, 4) - elif CASSANDRA_VERSION >= '2.1': + elif CASSANDRA_VERSION >= Version('2.1'): self.assertEqual(updated_protocol_version, 3) self.assertEqual(updated_cluster_version, 3) - elif CASSANDRA_VERSION >= '2.0': + elif CASSANDRA_VERSION >= Version('2.0'): self.assertEqual(updated_protocol_version, 2) self.assertEqual(updated_cluster_version, 2) else: @@ -141,30 +287,68 @@ def test_protocol_negotiation(self): cluster.shutdown() + def test_invalid_protocol_negotation(self): + """ + Test for protocol negotiation when explicit versions are set + + If an explicit protocol version that is not compatible with the server version is set + an exception should be thrown. 
It should not attempt to negotiate + + for reference supported protocol version to server versions is as follows/ + + 1.2 -> 1 + 2.0 -> 2, 1 + 2.1 -> 3, 2, 1 + 2.2 -> 4, 3, 2, 1 + 3.X -> 4, 3 + + @since 3.6.0 + @jira_ticket PYTHON-537 + @expected_result downgrading should not be allowed when explicit protocol versions are set. + + @test_category connection + """ + + upper_bound = get_unsupported_upper_protocol() + log.debug('got upper_bound of {}'.format(upper_bound)) + if upper_bound is not None: + cluster = TestCluster(protocol_version=upper_bound) + with self.assertRaises(NoHostAvailable): + cluster.connect() + cluster.shutdown() + + lower_bound = get_unsupported_lower_protocol() + log.debug('got lower_bound of {}'.format(lower_bound)) + if lower_bound is not None: + cluster = TestCluster(protocol_version=lower_bound) + with self.assertRaises(NoHostAvailable): + cluster.connect() + cluster.shutdown() + def test_connect_on_keyspace(self): """ Ensure clusters that connect on a keyspace, do """ - cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cluster = TestCluster() session = cluster.connect() result = session.execute( """ - INSERT INTO test3rf.test (k, v) VALUES (8889, 8889) + INSERT INTO test1rf.test (k, v) VALUES (8889, 8889) """) self.assertFalse(result) - result = session.execute("SELECT * FROM test3rf.test") - self.assertEqual([(8889, 8889)], result) + result = session.execute("SELECT * FROM test1rf.test") + self.assertEqual([(8889, 8889)], result, "Rows in ResultSet are {0}".format(result.current_rows)) # test_connect_on_keyspace - session2 = cluster.connect('test3rf') + session2 = cluster.connect('test1rf') result2 = session2.execute("SELECT * FROM test") self.assertEqual(result, result2) cluster.shutdown() def test_set_keyspace_twice(self): - cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cluster = TestCluster() session = cluster.connect() session.execute("USE system") session.execute("USE system") @@ -175,10 +359,8 @@ def test_default_connections(self): Ensure errors are not thrown when using non-default policies """ - Cluster( - load_balancing_policy=RoundRobinPolicy(), + TestCluster( reconnection_policy=ExponentialReconnectionPolicy(1.0, 600.0), - default_retry_policy=RetryPolicy(), conviction_policy_factory=SimpleConvictionPolicy, protocol_version=PROTOCOL_VERSION ) @@ -187,7 +369,7 @@ def test_connect_to_already_shutdown_cluster(self): """ Ensure you cannot connect to a cluster that's been shutdown """ - cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cluster = TestCluster() cluster.shutdown() self.assertRaises(Exception, cluster.connect) @@ -196,7 +378,7 @@ def test_auth_provider_is_callable(self): Ensure that auth_providers are always callable """ self.assertRaises(TypeError, Cluster, auth_provider=1, protocol_version=1) - c = Cluster(protocol_version=1) + c = TestCluster(protocol_version=1) self.assertRaises(TypeError, setattr, c, 'auth_provider', 1) def test_v2_auth_provider(self): @@ -205,7 +387,7 @@ def test_v2_auth_provider(self): """ bad_auth_provider = lambda x: {'username': 'foo', 'password': 'bar'} self.assertRaises(TypeError, Cluster, auth_provider=bad_auth_provider, protocol_version=2) - c = Cluster(protocol_version=2) + c = TestCluster(protocol_version=2) self.assertRaises(TypeError, setattr, c, 'auth_provider', bad_auth_provider) def test_conviction_policy_factory_is_callable(self): @@ -221,8 +403,8 @@ def test_connect_to_bad_hosts(self): when a cluster cannot connect to given hosts """ - cluster = Cluster(['127.1.2.9', '127.1.2.10'], - 
protocol_version=PROTOCOL_VERSION) + cluster = TestCluster(contact_points=['127.1.2.9', '127.1.2.10'], + protocol_version=PROTOCOL_VERSION) self.assertRaises(NoHostAvailable, cluster.connect) def test_cluster_settings(self): @@ -232,7 +414,7 @@ def test_cluster_settings(self): if PROTOCOL_VERSION >= 3: raise unittest.SkipTest("min/max requests and core/max conns aren't used with v3 protocol") - cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cluster = TestCluster() min_requests_per_connection = cluster.get_min_requests_per_connection(HostDistance.LOCAL) self.assertEqual(cassandra.cluster.DEFAULT_MIN_REQUESTS, min_requests_per_connection) @@ -255,7 +437,7 @@ def test_cluster_settings(self): self.assertEqual(cluster.get_max_connections_per_host(HostDistance.LOCAL), max_connections_per_host + 1) def test_refresh_schema(self): - cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cluster = TestCluster() session = cluster.connect() original_meta = cluster.metadata.keyspaces @@ -267,7 +449,7 @@ def test_refresh_schema(self): cluster.shutdown() def test_refresh_schema_keyspace(self): - cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cluster = TestCluster() session = cluster.connect() original_meta = cluster.metadata.keyspaces @@ -283,7 +465,7 @@ def test_refresh_schema_keyspace(self): cluster.shutdown() def test_refresh_schema_table(self): - cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cluster = TestCluster() session = cluster.connect() original_meta = cluster.metadata.keyspaces @@ -309,7 +491,7 @@ def test_refresh_schema_type(self): raise unittest.SkipTest('UDTs are not specified in change events for protocol v2') # We may want to refresh types on keyspace change events in that case(?) - cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cluster = TestCluster() session = cluster.connect() keyspace_name = 'test1rf' @@ -329,35 +511,37 @@ def test_refresh_schema_type(self): self.assertEqual(original_test1rf_meta.export_as_string(), current_test1rf_meta.export_as_string()) self.assertIsNot(original_type_meta, current_type_meta) self.assertEqual(original_type_meta.as_cql_query(), current_type_meta.as_cql_query()) - session.shutdown() + cluster.shutdown() + @local + @notwindows def test_refresh_schema_no_wait(self): + original_wait_for_responses = connection.Connection.wait_for_responses - contact_points = ['127.0.0.1'] - cluster = Cluster(protocol_version=PROTOCOL_VERSION, max_schema_agreement_wait=10, - contact_points=contact_points, load_balancing_policy=WhiteListRoundRobinPolicy(contact_points)) - session = cluster.connect() + def patched_wait_for_responses(*args, **kwargs): + # When selecting schema version, replace the real schema UUID with an unexpected UUID + response = original_wait_for_responses(*args, **kwargs) + if len(args) > 2 and hasattr(args[2], "query") and args[2].query == "SELECT schema_version FROM system.local WHERE key='local'": + new_uuid = uuid4() + response[1].parsed_rows[0] = (new_uuid,) + return response - schema_ver = session.execute("SELECT schema_version FROM system.local WHERE key='local'")[0][0] - new_schema_ver = uuid4() - session.execute("UPDATE system.local SET schema_version=%s WHERE key='local'", (new_schema_ver,)) - - try: + with patch.object(connection.Connection, "wait_for_responses", patched_wait_for_responses): agreement_timeout = 1 # cluster agreement wait exceeded - c = Cluster(protocol_version=PROTOCOL_VERSION, max_schema_agreement_wait=agreement_timeout) + c = TestCluster(max_schema_agreement_wait=agreement_timeout) 
c.connect() self.assertTrue(c.metadata.keyspaces) # cluster agreement wait used for refresh original_meta = c.metadata.keyspaces start_time = time.time() - self.assertRaisesRegexp(Exception, r"Schema metadata was not refreshed.*", c.refresh_schema_metadata) + self.assertRaisesRegex(Exception, r"Schema metadata was not refreshed.*", c.refresh_schema_metadata) end_time = time.time() self.assertGreaterEqual(end_time - start_time, agreement_timeout) self.assertIs(original_meta, c.metadata.keyspaces) - + # refresh wait overrides cluster value original_meta = c.metadata.keyspaces start_time = time.time() @@ -371,7 +555,7 @@ def test_refresh_schema_no_wait(self): refresh_threshold = 0.5 # cluster agreement bypass - c = Cluster(protocol_version=PROTOCOL_VERSION, max_schema_agreement_wait=0) + c = TestCluster(max_schema_agreement_wait=0) start_time = time.time() s = c.connect() end_time = time.time() @@ -386,45 +570,32 @@ def test_refresh_schema_no_wait(self): self.assertLess(end_time - start_time, refresh_threshold) self.assertIsNot(original_meta, c.metadata.keyspaces) self.assertEqual(original_meta, c.metadata.keyspaces) - + # refresh wait overrides cluster value original_meta = c.metadata.keyspaces start_time = time.time() - self.assertRaisesRegexp(Exception, r"Schema metadata was not refreshed.*", c.refresh_schema_metadata, + self.assertRaisesRegex(Exception, r"Schema metadata was not refreshed.*", c.refresh_schema_metadata, max_schema_agreement_wait=agreement_timeout) end_time = time.time() self.assertGreaterEqual(end_time - start_time, agreement_timeout) self.assertIs(original_meta, c.metadata.keyspaces) c.shutdown() - finally: - # TODO once fixed this connect call - session = cluster.connect() - session.execute("UPDATE system.local SET schema_version=%s WHERE key='local'", (schema_ver,)) - - cluster.shutdown() def test_trace(self): """ Ensure trace can be requested for async and non-async queries """ - cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cluster = TestCluster() session = cluster.connect() - def check_trace(trace): - self.assertIsNotNone(trace.request_type) - self.assertIsNotNone(trace.duration) - self.assertIsNotNone(trace.started_at) - self.assertIsNotNone(trace.coordinator) - self.assertIsNotNone(trace.events) - result = session.execute( "SELECT * FROM system.local", trace=True) - check_trace(result.get_query_trace()) + self._check_trace(result.get_query_trace()) query = "SELECT * FROM system.local" statement = SimpleStatement(query) result = session.execute(statement, trace=True) - check_trace(result.get_query_trace()) + self._check_trace(result.get_query_trace()) query = "SELECT * FROM system.local" statement = SimpleStatement(query) @@ -434,7 +605,7 @@ def check_trace(trace): statement2 = SimpleStatement(query) future = session.execute_async(statement2, trace=True) future.result() - check_trace(future.get_query_trace()) + self._check_trace(future.get_query_trace()) statement2 = SimpleStatement(query) future = session.execute_async(statement2) @@ -444,26 +615,75 @@ def check_trace(trace): prepared = session.prepare("SELECT * FROM system.local") future = session.execute_async(prepared, parameters=(), trace=True) future.result() - check_trace(future.get_query_trace()) + self._check_trace(future.get_query_trace()) cluster.shutdown() - def test_trace_timeout(self): - cluster = Cluster(protocol_version=PROTOCOL_VERSION) + def test_trace_unavailable(self): + """ + First checks that TraceUnavailable is arisen if the + max_wait parameter is negative + + Then checks that 
TraceUnavailable is raised if the + result hasn't been set yet + + @since 3.10 + @jira_ticket PYTHON-196 + @expected_result TraceUnavailable is raised in both cases + + @test_category query + """ + cluster = TestCluster() + self.addCleanup(cluster.shutdown) session = cluster.connect() query = "SELECT * FROM system.local" statement = SimpleStatement(query) - future = session.execute_async(statement, trace=True) - future.result() - self.assertRaises(TraceUnavailable, future.get_query_trace, -1.0) - cluster.shutdown() + + max_retry_count = 10 + for i in range(max_retry_count): + future = session.execute_async(statement, trace=True) + future.result() + try: + result = future.get_query_trace(-1.0) + # In case the result has time to come back before this timeout due to a race condition + self._check_trace(result) + except TraceUnavailable: + break + else: + raise Exception("get_query_trace didn't raise TraceUnavailable after {} tries".format(max_retry_count)) + + + for i in range(max_retry_count): + future = session.execute_async(statement, trace=True) + try: + result = future.get_query_trace(max_wait=120) + # In case the result has been set check the trace + self._check_trace(result) + except TraceUnavailable: + break + else: + raise Exception("get_query_trace didn't raise TraceUnavailable after {} tries".format(max_retry_count)) + + def test_one_returns_none(self): + """ + Test ResultSet.one returns None if no rows were found + + @since 3.14 + @jira_ticket PYTHON-947 + @expected_result ResultSet.one is None + + @test_category query + """ + with TestCluster() as cluster: + session = cluster.connect() + self.assertIsNone(session.execute("SELECT * from system.local WHERE key='madeup_key'").one()) def test_string_coverage(self): """ Ensure str(future) returns without error """ - cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cluster = TestCluster() session = cluster.connect() query = "SELECT * FROM system.local" @@ -477,19 +697,75 @@ def test_string_coverage(self): self.assertIn('result', str(future)) cluster.shutdown() + def test_can_connect_with_plainauth(self): + """ + Verify that we can connect setting PlainTextAuthProvider against a + C* server without authentication set. We also verify a warning is + issued per connection. This test is here instead of in test_authentication.py + because the C* server running in that module has auth set. + + @since 3.14 + @jira_ticket PYTHON-940 + @expected_result we can connect, query C* and warnings are issued + + @test_category auth + """ + auth_provider = PlainTextAuthProvider( + username="made_up_username", + password="made_up_password" + ) + self._warning_are_issued_when_auth(auth_provider) + + def test_can_connect_with_sslauth(self): + """ + Verify that we can connect setting SaslAuthProvider against a + C* server without authentication set. We also verify a warning is + issued per connection. This test is here instead of in test_authentication.py + because the C* server running in that module has auth set.
+ + @since 3.14 + @jira_ticket PYTHON-940 + @expected_result we can connect, query C* and warning are issued + + @test_category auth + """ + sasl_kwargs = {'service': 'cassandra', + 'mechanism': 'PLAIN', + 'qops': ['auth'], + 'username': "made_up_username", + 'password': "made_up_password"} + + auth_provider = SaslAuthProvider(**sasl_kwargs) + self._warning_are_issued_when_auth(auth_provider) + + def _warning_are_issued_when_auth(self, auth_provider): + with MockLoggingHandler().set_module_name(connection.__name__) as mock_handler: + with TestCluster(auth_provider=auth_provider) as cluster: + session = cluster.connect() + self.assertIsNotNone(session.execute("SELECT * from system.local")) + + # Three conenctions to nodes plus the control connection + auth_warning = mock_handler.get_message_count('warning', "An authentication challenge was not sent") + self.assertGreaterEqual(auth_warning, 4) + self.assertEqual( + auth_warning, + mock_handler.get_message_count("debug", "Got ReadyMessage on new connection") + ) + def test_idle_heartbeat(self): interval = 2 - cluster = Cluster(protocol_version=PROTOCOL_VERSION, idle_heartbeat_interval=interval) + cluster = TestCluster(idle_heartbeat_interval=interval, + monitor_reporting_enabled=False) if PROTOCOL_VERSION < 3: cluster.set_core_connections_per_host(HostDistance.LOCAL, 1) - session = cluster.connect() + session = cluster.connect(wait_for_all_pools=True) - # This test relies on impl details of connection req id management to see if heartbeats + # This test relies on impl details of connection req id management to see if heartbeats # are being sent. May need update if impl is changed connection_request_ids = {} for h in cluster.get_connection_holders(): for c in h.get_connections(): - # make sure none are idle (should have startup messages) + # make sure none are idle (should have startup messages self.assertFalse(c.is_idle) with c.lock: connection_request_ids[id(c)] = deque(c.request_ids) # copy of request ids @@ -524,7 +800,7 @@ def test_idle_heartbeat(self): self.assertEqual(len(holders), len(cluster.metadata.all_hosts()) + 1) # hosts pools, 1 for cc # include additional sessions - session2 = cluster.connect() + session2 = cluster.connect(wait_for_all_pools=True) holders = cluster.get_connection_holders() self.assertIn(cluster.control_connection, holders) @@ -541,7 +817,7 @@ def test_idle_heartbeat_disabled(self): self.assertTrue(Cluster.idle_heartbeat_interval) # heartbeat disabled with '0' - cluster = Cluster(protocol_version=PROTOCOL_VERSION, idle_heartbeat_interval=0) + cluster = TestCluster(idle_heartbeat_interval=0) self.assertEqual(cluster.idle_heartbeat_interval, 0) session = cluster.connect() @@ -557,7 +833,7 @@ def test_idle_heartbeat_disabled(self): def test_pool_management(self): # Ensure that in_flight and request_ids quiesce after cluster operations - cluster = Cluster(protocol_version=PROTOCOL_VERSION, idle_heartbeat_interval=0) # no idle heartbeat here, pool management is tested in test_idle_heartbeat + cluster = TestCluster(idle_heartbeat_interval=0) # no idle heartbeat here, pool management is tested in test_idle_heartbeat session = cluster.connect() session2 = cluster.connect() @@ -584,4 +860,732 @@ def test_pool_management(self): cluster.shutdown() + @local + def test_profile_load_balancing(self): + """ + Tests that profile load balancing policies are honored. + + @since 3.5 + @jira_ticket PYTHON-569 + @expected_result Execution Policy should be used when applicable. 
+ + @test_category config_profiles + """ + query = "select release_version from system.local" + node1 = ExecutionProfile( + load_balancing_policy=HostFilterPolicy( + RoundRobinPolicy(), lambda host: host.address == CASSANDRA_IP + ) + ) + with TestCluster(execution_profiles={'node1': node1}, monitor_reporting_enabled=False) as cluster: + session = cluster.connect(wait_for_all_pools=True) + + # default is DCA RR for all hosts + expected_hosts = set(cluster.metadata.all_hosts()) + queried_hosts = set() + for _ in expected_hosts: + rs = session.execute(query) + queried_hosts.add(rs.response_future._current_host) + self.assertEqual(queried_hosts, expected_hosts) + + # by name we should only hit the one + expected_hosts = set(h for h in cluster.metadata.all_hosts() if h.address == CASSANDRA_IP) + queried_hosts = set() + for _ in cluster.metadata.all_hosts(): + rs = session.execute(query, execution_profile='node1') + queried_hosts.add(rs.response_future._current_host) + self.assertEqual(queried_hosts, expected_hosts) + + # use a copied instance and override the row factory + # assert last returned value can be accessed as a namedtuple so we can prove something different + named_tuple_row = rs[0] + self.assertIsInstance(named_tuple_row, tuple) + self.assertTrue(named_tuple_row.release_version) + + tmp_profile = copy(node1) + tmp_profile.row_factory = tuple_factory + queried_hosts = set() + for _ in cluster.metadata.all_hosts(): + rs = session.execute(query, execution_profile=tmp_profile) + queried_hosts.add(rs.response_future._current_host) + self.assertEqual(queried_hosts, expected_hosts) + tuple_row = rs[0] + self.assertIsInstance(tuple_row, tuple) + with self.assertRaises(AttributeError): + tuple_row.release_version + + # make sure original profile is not impacted + self.assertTrue(session.execute(query, execution_profile='node1')[0].release_version) + + def test_setting_lbp_legacy(self): + cluster = TestCluster() + self.addCleanup(cluster.shutdown) + cluster.load_balancing_policy = RoundRobinPolicy() + self.assertEqual( + list(cluster.load_balancing_policy.make_query_plan()), [] + ) + cluster.connect() + self.assertNotEqual( + list(cluster.load_balancing_policy.make_query_plan()), [] + ) + + def test_profile_lb_swap(self): + """ + Tests that profile load balancing policies are not shared + + Creates two LBP, runs a few queries, and validates that each LBP is execised + seperately between EP's + + @since 3.5 + @jira_ticket PYTHON-569 + @expected_result LBP should not be shared. 
+ + @test_category config_profiles + """ + query = "select release_version from system.local" + rr1 = ExecutionProfile(load_balancing_policy=RoundRobinPolicy()) + rr2 = ExecutionProfile(load_balancing_policy=RoundRobinPolicy()) + exec_profiles = {'rr1': rr1, 'rr2': rr2} + with TestCluster(execution_profiles=exec_profiles) as cluster: + session = cluster.connect(wait_for_all_pools=True) + + # default is DCA RR for all hosts + expected_hosts = set(cluster.metadata.all_hosts()) + rr1_queried_hosts = set() + rr2_queried_hosts = set() + + rs = session.execute(query, execution_profile='rr1') + rr1_queried_hosts.add(rs.response_future._current_host) + rs = session.execute(query, execution_profile='rr2') + rr2_queried_hosts.add(rs.response_future._current_host) + + self.assertEqual(rr2_queried_hosts, rr1_queried_hosts) + + def test_ta_lbp(self): + """ + Test that execution profiles containing token aware LBP can be added + + @since 3.5 + @jira_ticket PYTHON-569 + @expected_result Queries can run + + @test_category config_profiles + """ + query = "select release_version from system.local" + ta1 = ExecutionProfile() + with TestCluster() as cluster: + session = cluster.connect() + cluster.add_execution_profile("ta1", ta1) + rs = session.execute(query, execution_profile='ta1') + + def test_clone_shared_lbp(self): + """ + Tests that profile load balancing policies are shared on clone + + Creates one LBP clones it, and ensures that the LBP is shared between + the two EP's + + @since 3.5 + @jira_ticket PYTHON-569 + @expected_result LBP is shared + + @test_category config_profiles + """ + query = "select release_version from system.local" + rr1 = ExecutionProfile(load_balancing_policy=RoundRobinPolicy()) + exec_profiles = {'rr1': rr1} + with TestCluster(execution_profiles=exec_profiles) as cluster: + session = cluster.connect(wait_for_all_pools=True) + self.assertGreater(len(cluster.metadata.all_hosts()), 1, "We only have one host connected at this point") + + rr1_clone = session.execution_profile_clone_update('rr1', row_factory=tuple_factory) + cluster.add_execution_profile("rr1_clone", rr1_clone) + rr1_queried_hosts = set() + rr1_clone_queried_hosts = set() + rs = session.execute(query, execution_profile='rr1') + rr1_queried_hosts.add(rs.response_future._current_host) + rs = session.execute(query, execution_profile='rr1_clone') + rr1_clone_queried_hosts.add(rs.response_future._current_host) + self.assertNotEqual(rr1_clone_queried_hosts, rr1_queried_hosts) + + def test_missing_exec_prof(self): + """ + Tests to verify that using an unknown profile raises a ValueError + + @since 3.5 + @jira_ticket PYTHON-569 + @expected_result ValueError + @test_category config_profiles + """ + query = "select release_version from system.local" + rr1 = ExecutionProfile(load_balancing_policy=RoundRobinPolicy()) + rr2 = ExecutionProfile(load_balancing_policy=RoundRobinPolicy()) + exec_profiles = {'rr1': rr1, 'rr2': rr2} + with TestCluster(execution_profiles=exec_profiles) as cluster: + session = cluster.connect() + with self.assertRaises(ValueError): + session.execute(query, execution_profile='rr3') + + @local + def test_profile_pool_management(self): + """ + Tests that changes to execution profiles correctly impact our cluster's pooling + + @since 3.5 + @jira_ticket PYTHON-569 + @expected_result pools should be correctly updated as EP's are added and removed + + @test_category config_profiles + """ + + node1 = ExecutionProfile( + load_balancing_policy=HostFilterPolicy( + RoundRobinPolicy(), lambda host: host.address == 
"127.0.0.1" + ) + ) + node2 = ExecutionProfile( + load_balancing_policy=HostFilterPolicy( + RoundRobinPolicy(), lambda host: host.address == "127.0.0.2" + ) + ) + with TestCluster(execution_profiles={EXEC_PROFILE_DEFAULT: node1, 'node2': node2}) as cluster: + session = cluster.connect(wait_for_all_pools=True) + pools = session.get_pool_state() + # there are more hosts, but we connected to the ones in the lbp aggregate + self.assertGreater(len(cluster.metadata.all_hosts()), 2) + self.assertEqual(set(h.address for h in pools), set(('127.0.0.1', '127.0.0.2'))) + + # dynamically update pools on add + node3 = ExecutionProfile( + load_balancing_policy=HostFilterPolicy( + RoundRobinPolicy(), lambda host: host.address == "127.0.0.3" + ) + ) + cluster.add_execution_profile('node3', node3) + pools = session.get_pool_state() + self.assertEqual(set(h.address for h in pools), set(('127.0.0.1', '127.0.0.2', '127.0.0.3'))) + + @local + def test_add_profile_timeout(self): + """ + Tests that EP Timeouts are honored. + + @since 3.5 + @jira_ticket PYTHON-569 + @expected_result EP timeouts should override defaults + + @test_category config_profiles + """ + max_retry_count = 10 + for i in range(max_retry_count): + node1 = ExecutionProfile( + load_balancing_policy=HostFilterPolicy( + RoundRobinPolicy(), lambda host: host.address == "127.0.0.1" + ) + ) + with TestCluster(execution_profiles={EXEC_PROFILE_DEFAULT: node1}) as cluster: + session = cluster.connect(wait_for_all_pools=True) + pools = session.get_pool_state() + self.assertGreater(len(cluster.metadata.all_hosts()), 2) + self.assertEqual(set(h.address for h in pools), set(('127.0.0.1',))) + + node2 = ExecutionProfile( + load_balancing_policy=HostFilterPolicy( + RoundRobinPolicy(), lambda host: host.address in ["127.0.0.2", "127.0.0.3"] + ) + ) + + start = time.time() + try: + self.assertRaises(cassandra.OperationTimedOut, cluster.add_execution_profile, + 'profile_{0}'.format(i), + node2, pool_wait_timeout=sys.float_info.min) + break + except AssertionError: + end = time.time() + self.assertAlmostEqual(start, end, 1) + else: + raise Exception("add_execution_profile didn't timeout after {0} retries".format(max_retry_count)) + + @notwindows + def test_execute_query_timeout(self): + with TestCluster() as cluster: + session = cluster.connect(wait_for_all_pools=True) + query = "SELECT * FROM system.local" + + # default is passed down + default_profile = cluster.profile_manager.profiles[EXEC_PROFILE_DEFAULT] + rs = session.execute(query) + self.assertEqual(rs.response_future.timeout, default_profile.request_timeout) + + # tiny timeout times out as expected + tmp_profile = copy(default_profile) + tmp_profile.request_timeout = sys.float_info.min + + max_retry_count = 10 + for _ in range(max_retry_count): + start = time.time() + try: + with self.assertRaises(cassandra.OperationTimedOut): + session.execute(query, execution_profile=tmp_profile) + break + except: + import traceback + traceback.print_exc() + end = time.time() + self.assertAlmostEqual(start, end, 1) + else: + raise Exception("session.execute didn't time out in {0} tries".format(max_retry_count)) + + def test_replicas_are_queried(self): + """ + Test that replicas are queried first for TokenAwarePolicy. A table with RF 1 + is created. All the queries should go to that replica when TokenAwarePolicy + is used. + Then using HostFilterPolicy the replica is excluded from the considered hosts. + By checking the trace we verify that there are no more replicas. 
+ + @since 3.5 + @jira_ticket PYTHON-653 + @expected_result the replicas are queried for HostFilterPolicy + + @test_category metadata + """ + queried_hosts = set() + tap_profile = ExecutionProfile( + load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()) + ) + with TestCluster(execution_profiles={EXEC_PROFILE_DEFAULT: tap_profile}) as cluster: + session = cluster.connect(wait_for_all_pools=True) + session.execute(''' + CREATE TABLE test1rf.table_with_big_key ( + k1 int, + k2 int, + k3 int, + k4 int, + PRIMARY KEY((k1, k2, k3), k4))''') + prepared = session.prepare("""SELECT * from test1rf.table_with_big_key + WHERE k1 = ? AND k2 = ? AND k3 = ? AND k4 = ?""") + for i in range(10): + result = session.execute(prepared, (i, i, i, i), trace=True) + trace = result.response_future.get_query_trace(query_cl=ConsistencyLevel.ALL) + queried_hosts = self._assert_replica_queried(trace, only_replicas=True) + last_i = i + + hfp_profile = ExecutionProfile( + load_balancing_policy=HostFilterPolicy(RoundRobinPolicy(), + predicate=lambda host: host.address != only_replica) + ) + only_replica = queried_hosts.pop() + log = logging.getLogger(__name__) + log.info("The only replica found was: {}".format(only_replica)) + available_hosts = [host for host in ["127.0.0.1", "127.0.0.2", "127.0.0.3"] if host != only_replica] + with TestCluster(contact_points=available_hosts, + execution_profiles={EXEC_PROFILE_DEFAULT: hfp_profile}) as cluster: + + session = cluster.connect(wait_for_all_pools=True) + prepared = session.prepare("""SELECT * from test1rf.table_with_big_key + WHERE k1 = ? AND k2 = ? AND k3 = ? AND k4 = ?""") + for _ in range(10): + result = session.execute(prepared, (last_i, last_i, last_i, last_i), trace=True) + trace = result.response_future.get_query_trace(query_cl=ConsistencyLevel.ALL) + self._assert_replica_queried(trace, only_replicas=False) + + session.execute('''DROP TABLE test1rf.table_with_big_key''') + + @unittest.skip + @greaterthanorequalcass30 + @lessthanorequalcass40 + def test_compact_option(self): + """ + Test the driver can connect with the no_compact option and the results + are as expected. 
This test is very similar to the corresponding dtest + + @since 3.12 + @jira_ticket PYTHON-366 + @expected_result only one hosts' metadata will be populated + + @test_category connection + """ + nc_cluster = TestCluster(no_compact=True) + nc_session = nc_cluster.connect() + + cluster = TestCluster(no_compact=False) + session = cluster.connect() + + self.addCleanup(cluster.shutdown) + self.addCleanup(nc_cluster.shutdown) + + nc_session.set_keyspace("test3rf") + session.set_keyspace("test3rf") + + nc_session.execute( + "CREATE TABLE IF NOT EXISTS compact_table (k int PRIMARY KEY, v1 int, v2 int) WITH COMPACT STORAGE;") + + for i in range(1, 5): + nc_session.execute( + "INSERT INTO compact_table (k, column1, v1, v2, value) VALUES " + "({i}, 'a{i}', {i}, {i}, textAsBlob('b{i}'))".format(i=i)) + nc_session.execute( + "INSERT INTO compact_table (k, column1, v1, v2, value) VALUES " + "({i}, 'a{i}{i}', {i}{i}, {i}{i}, textAsBlob('b{i}{i}'))".format(i=i)) + + nc_results = nc_session.execute("SELECT * FROM compact_table") + self.assertEqual( + set(nc_results.current_rows), + {(1, u'a1', 11, 11, 'b1'), + (1, u'a11', 11, 11, 'b11'), + (2, u'a2', 22, 22, 'b2'), + (2, u'a22', 22, 22, 'b22'), + (3, u'a3', 33, 33, 'b3'), + (3, u'a33', 33, 33, 'b33'), + (4, u'a4', 44, 44, 'b4'), + (4, u'a44', 44, 44, 'b44')}) + + results = session.execute("SELECT * FROM compact_table") + self.assertEqual( + set(results.current_rows), + {(1, 11, 11), + (2, 22, 22), + (3, 33, 33), + (4, 44, 44)}) + + def _assert_replica_queried(self, trace, only_replicas=True): + queried_hosts = set() + for row in trace.events: + queried_hosts.add(row.source) + if only_replicas: + self.assertEqual(len(queried_hosts), 1, "The hosts queried where {}".format(queried_hosts)) + else: + self.assertGreater(len(queried_hosts), 1, "The host queried was {}".format(queried_hosts)) + return queried_hosts + + def _check_trace(self, trace): + self.assertIsNotNone(trace.request_type) + self.assertIsNotNone(trace.duration) + self.assertIsNotNone(trace.started_at) + self.assertIsNotNone(trace.coordinator) + self.assertIsNotNone(trace.events) + + +class LocalHostAdressTranslator(AddressTranslator): + + def __init__(self, addr_map=None): + self.addr_map = addr_map + + def translate(self, addr): + new_addr = self.addr_map.get(addr) + return new_addr + +@local +class TestAddressTranslation(unittest.TestCase): + + def test_address_translator_basic(self): + """ + Test host address translation + + Uses a custom Address Translator to map all ip back to one. 
+ Validates AddressTranslator invocation by ensuring that only meta data associated with single + host is populated + + @since 3.3 + @jira_ticket PYTHON-69 + @expected_result only one hosts' metadata will be populated + + @test_category metadata + """ + lh_ad = LocalHostAdressTranslator({'127.0.0.1': '127.0.0.1', '127.0.0.2': '127.0.0.1', '127.0.0.3': '127.0.0.1'}) + c = TestCluster(address_translator=lh_ad) + c.connect() + self.assertEqual(len(c.metadata.all_hosts()), 1) + c.shutdown() + + def test_address_translator_with_mixed_nodes(self): + """ + Test host address translation + + Uses a custom Address Translator to map ip's of non control_connection nodes to each other + Validates AddressTranslator invocation by ensuring that metadata for mapped hosts is also mapped + + @since 3.3 + @jira_ticket PYTHON-69 + @expected_result metadata for crossed hosts will also be crossed + + @test_category metadata + """ + adder_map = {'127.0.0.1': '127.0.0.1', '127.0.0.2': '127.0.0.3', '127.0.0.3': '127.0.0.2'} + lh_ad = LocalHostAdressTranslator(adder_map) + c = TestCluster(address_translator=lh_ad) + c.connect() + for host in c.metadata.all_hosts(): + self.assertEqual(adder_map.get(host.address), host.broadcast_address) + c.shutdown() + +@local +class ContextManagementTest(unittest.TestCase): + load_balancing_policy = HostFilterPolicy( + RoundRobinPolicy(), lambda host: host.address == CASSANDRA_IP + ) + cluster_kwargs = {'execution_profiles': {EXEC_PROFILE_DEFAULT: ExecutionProfile(load_balancing_policy= + load_balancing_policy)}, + 'schema_metadata_enabled': False, + 'token_metadata_enabled': False} + + def test_no_connect(self): + """ + Test cluster context without connecting. + + @since 3.4 + @jira_ticket PYTHON-521 + @expected_result context should still be valid + + @test_category configuration + """ + with TestCluster() as cluster: + self.assertFalse(cluster.is_shutdown) + self.assertTrue(cluster.is_shutdown) + + def test_simple_nested(self): + """ + Test cluster and session contexts nested in one another. + + @since 3.4 + @jira_ticket PYTHON-521 + @expected_result cluster/session should be crated and shutdown appropriately. + + @test_category configuration + """ + with TestCluster(**self.cluster_kwargs) as cluster: + with cluster.connect() as session: + self.assertFalse(cluster.is_shutdown) + self.assertFalse(session.is_shutdown) + self.assertTrue(session.execute('select release_version from system.local')[0]) + self.assertTrue(session.is_shutdown) + self.assertTrue(cluster.is_shutdown) + + def test_cluster_no_session(self): + """ + Test cluster context without session context. + + @since 3.4 + @jira_ticket PYTHON-521 + @expected_result Session should be created correctly. Cluster should shutdown outside of context + + @test_category configuration + """ + with TestCluster(**self.cluster_kwargs) as cluster: + session = cluster.connect() + self.assertFalse(cluster.is_shutdown) + self.assertFalse(session.is_shutdown) + self.assertTrue(session.execute('select release_version from system.local')[0]) + self.assertTrue(session.is_shutdown) + self.assertTrue(cluster.is_shutdown) + + def test_session_no_cluster(self): + """ + Test session context without cluster context. + + @since 3.4 + @jira_ticket PYTHON-521 + @expected_result session should be created correctly. 
Session should shutdown correctly outside of context + + @test_category configuration + """ + cluster = TestCluster(**self.cluster_kwargs) + unmanaged_session = cluster.connect() + with cluster.connect() as session: + self.assertFalse(cluster.is_shutdown) + self.assertFalse(session.is_shutdown) + self.assertFalse(unmanaged_session.is_shutdown) + self.assertTrue(session.execute('select release_version from system.local')[0]) + self.assertTrue(session.is_shutdown) + self.assertFalse(cluster.is_shutdown) + self.assertFalse(unmanaged_session.is_shutdown) + unmanaged_session.shutdown() + self.assertTrue(unmanaged_session.is_shutdown) + self.assertFalse(cluster.is_shutdown) + cluster.shutdown() + self.assertTrue(cluster.is_shutdown) + + +class HostStateTest(unittest.TestCase): + + def test_down_event_with_active_connection(self): + """ + Test to ensure that on down calls to clusters with connections still active don't result in + a host being marked down. The second part of the test kills the connection then invokes + on_down, and ensures the state changes for host's metadata. + + @since 3.7 + @jira_ticket PYTHON-498 + @expected_result host should never be toggled down while a connection is active. + + @test_category connection + """ + with TestCluster() as cluster: + session = cluster.connect(wait_for_all_pools=True) + random_host = cluster.metadata.all_hosts()[0] + cluster.on_down(random_host, False) + for _ in range(10): + new_host = cluster.metadata.all_hosts()[0] + self.assertTrue(new_host.is_up, "Host was not up on iteration {0}".format(_)) + time.sleep(.01) + + pool = session._pools.get(random_host) + pool.shutdown() + cluster.on_down(random_host, False) + was_marked_down = False + for _ in range(20): + new_host = cluster.metadata.all_hosts()[0] + if not new_host.is_up: + was_marked_down = True + break + time.sleep(.01) + self.assertTrue(was_marked_down) + + +@local +class DontPrepareOnIgnoredHostsTest(unittest.TestCase): + ignored_addresses = ['127.0.0.3'] + ignore_node_3_policy = IgnoredHostPolicy(ignored_addresses) + + def test_prepare_on_ignored_hosts(self): + + cluster = TestCluster( + execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(load_balancing_policy=self.ignore_node_3_policy)} + ) + session = cluster.connect() + cluster.reprepare_on_up, cluster.prepare_on_all_hosts = True, False + + hosts = cluster.metadata.all_hosts() + session.execute("CREATE KEYSPACE clustertests " + "WITH replication = " + "{'class': 'SimpleStrategy', 'replication_factor': '1'}") + session.execute("CREATE TABLE clustertests.tab (a text, PRIMARY KEY (a))") + # assign to an unused variable so cluster._prepared_statements retains + # reference + _ = session.prepare("INSERT INTO clustertests.tab (a) VALUES ('a')") # noqa + + cluster.connection_factory = Mock(wraps=cluster.connection_factory) + + unignored_address = '127.0.0.1' + unignored_host = next(h for h in hosts if h.address == unignored_address) + ignored_host = next(h for h in hosts if h.address in self.ignored_addresses) + unignored_host.is_up = ignored_host.is_up = False + + cluster.on_up(unignored_host) + cluster.on_up(ignored_host) + + # the length of mock_calls will vary, but all should use the unignored + # address + for c in cluster.connection_factory.mock_calls: + # PYTHON-1287 + # + # Cluster._prepare_all_queries() will call connection_factory _without_ the + # on_orphaned_stream_released arg introduced in commit + # 387150acc365b6cf1daaee58c62db13e4929099a. 
The reconnect handler for the + # downed node _will_ add this arg when it tries to rebuild it's conn pool, and + # whether this occurs while running this test amounts to a race condition. So + # to cover this case we assert one of two call styles here... the key is that + # the _only_ address we should see is the unignored_address. + self.assertTrue( \ + c == call(DefaultEndPoint(unignored_address)) or \ + c == call(DefaultEndPoint(unignored_address), on_orphaned_stream_released=ANY)) + cluster.shutdown() + + +@protocolv6 +class BetaProtocolTest(unittest.TestCase): + + @protocolv6 + def test_invalid_protocol_version_beta_option(self): + """ + Test cluster connection with protocol v6 and beta flag not set + + @since 3.7.0 + @jira_ticket PYTHON-614, PYTHON-1232 + @expected_result client shouldn't connect with V6 and no beta flag set + + @test_category connection + """ + + cluster = TestCluster(protocol_version=cassandra.ProtocolVersion.V6, allow_beta_protocol_version=False) + try: + with self.assertRaises(NoHostAvailable): + cluster.connect() + except Exception as e: + self.fail("Unexpected error encountered {0}".format(e.message)) + + @protocolv6 + def test_valid_protocol_version_beta_options_connect(self): + """ + Test cluster connection with protocol version 5 and beta flag set + + @since 3.7.0 + @jira_ticket PYTHON-614, PYTHON-1232 + @expected_result client should connect with protocol v6 and beta flag set. + + @test_category connection + """ + cluster = Cluster(protocol_version=cassandra.ProtocolVersion.V6, allow_beta_protocol_version=True) + session = cluster.connect() + self.assertEqual(cluster.protocol_version, cassandra.ProtocolVersion.V6) + self.assertTrue(session.execute("select release_version from system.local")[0]) + cluster.shutdown() + + +class DeprecationWarningTest(unittest.TestCase): + def test_deprecation_warnings_legacy_parameters(self): + """ + Tests the deprecation warning has been added when using + legacy parameters + + @since 3.13 + @jira_ticket PYTHON-877 + @expected_result the deprecation warning is emitted + + @test_category logs + """ + with warnings.catch_warnings(record=True) as w: + TestCluster(load_balancing_policy=RoundRobinPolicy()) + self.assertEqual(len(w), 1) + self.assertIn("Legacy execution parameters will be removed in 4.0. 
Consider using execution profiles.", + str(w[0].message)) + + def test_deprecation_warnings_meta_refreshed(self): + """ + Tests the deprecation warning has been added when enabling + metadata refreshment + + @since 3.13 + @jira_ticket PYTHON-890 + @expected_result the deprecation warning is emitted + + @test_category logs + """ + with warnings.catch_warnings(record=True) as w: + cluster = TestCluster() + cluster.set_meta_refresh_enabled(True) + self.assertEqual(len(w), 1) + self.assertIn("Cluster.set_meta_refresh_enabled is deprecated and will be removed in 4.0.", + str(w[0].message)) + + def test_deprecation_warning_default_consistency_level(self): + """ + Tests the deprecation warning has been added when enabling + session the default consistency level to session + + @since 3.14 + @jira_ticket PYTHON-935 + @expected_result the deprecation warning is emitted + + @test_category logs + """ + with warnings.catch_warnings(record=True) as w: + cluster = TestCluster() + session = cluster.connect() + session.default_consistency_level = ConsistencyLevel.ONE + self.assertEqual(len(w), 1) + self.assertIn("Setting the consistency level at the session level will be removed in 4.0", + str(w[0].message)) diff --git a/tests/integration/standard/test_concurrent.py b/tests/integration/standard/test_concurrent.py index 60243bb654..c935763bcb 100644 --- a/tests/integration/standard/test_concurrent.py +++ b/tests/integration/standard/test_concurrent.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -13,24 +15,18 @@ # limitations under the License. 
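The deprecation tests above steer users away from legacy Cluster/Session parameters toward execution profiles. A hedged sketch of the replacement style, using only APIs that appear elsewhere in this patch (policy, consistency level and timeout values are illustrative):

from cassandra import ConsistencyLevel
from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT
from cassandra.policies import RoundRobinPolicy

# Bundle load balancing, consistency and timeout settings into a profile
# instead of passing legacy keyword arguments or mutating the session later.
profile = ExecutionProfile(
    load_balancing_policy=RoundRobinPolicy(),
    consistency_level=ConsistencyLevel.LOCAL_QUORUM,
    request_timeout=15,
)
cluster = Cluster(execution_profiles={EXEC_PROFILE_DEFAULT: profile})
session = cluster.connect()
session.execute("SELECT release_version FROM system.local")
cluster.shutdown()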
from itertools import cycle -from six import next import sys, logging, traceback from cassandra import InvalidRequest, ConsistencyLevel, ReadTimeout, WriteTimeout, OperationTimedOut, \ ReadFailure, WriteFailure -from cassandra.cluster import Cluster -from cassandra.concurrent import execute_concurrent, execute_concurrent_with_args +from cassandra.cluster import ExecutionProfile, EXEC_PROFILE_DEFAULT +from cassandra.concurrent import execute_concurrent, execute_concurrent_with_args, ExecutionResult from cassandra.policies import HostDistance -from cassandra.query import tuple_factory, SimpleStatement - -from tests.integration import use_singledc, PROTOCOL_VERSION +from cassandra.query import dict_factory, tuple_factory, SimpleStatement -from six import next +from tests.integration import use_singledc, PROTOCOL_VERSION, TestCluster -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest log = logging.getLogger(__name__) @@ -39,46 +35,52 @@ def setup_module(): use_singledc() +EXEC_PROFILE_DICT = "dict" + class ClusterTests(unittest.TestCase): @classmethod def setUpClass(cls): - cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cls.cluster = TestCluster( + execution_profiles = { + EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=tuple_factory), + EXEC_PROFILE_DICT: ExecutionProfile(row_factory=dict_factory) + } + ) if PROTOCOL_VERSION < 3: cls.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1) cls.session = cls.cluster.connect() - cls.session.row_factory = tuple_factory @classmethod def tearDownClass(cls): cls.cluster.shutdown() - def execute_concurrent_helper(self, session, query, results_generator=False): + def execute_concurrent_helper(self, session, query, **kwargs): count = 0 while count < 100: try: - return execute_concurrent(session, query, results_generator=False) + return execute_concurrent(session, query, results_generator=False, **kwargs) except (ReadTimeout, WriteTimeout, OperationTimedOut, ReadFailure, WriteFailure): ex_type, ex, tb = sys.exc_info() - log.warn("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) del tb count += 1 raise RuntimeError("Failed to execute query after 100 attempts: {0}".format(query)) - def execute_concurrent_args_helper(self, session, query, params, results_generator=False): + def execute_concurrent_args_helper(self, session, query, params, results_generator=False, **kwargs): count = 0 while count < 100: try: - return execute_concurrent_with_args(session, query, params, results_generator=results_generator) + return execute_concurrent_with_args(session, query, params, results_generator=results_generator, **kwargs) except (ReadTimeout, WriteTimeout, OperationTimedOut, ReadFailure, WriteFailure): ex_type, ex, tb = sys.exc_info() - log.warn("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) + log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) del tb raise RuntimeError("Failed to execute query after 100 attempts: {0}".format(query)) - def test_execute_concurrent(self): + def execute_concurrent_base(self, test_fn, validate_fn, zip_args=True): for num_statements in (0, 1, 2, 7, 10, 99, 100, 101, 199, 200, 201): # write statement = SimpleStatement( @@ -87,7 +89,9 @@ def test_execute_concurrent(self): statements = cycle((statement, )) parameters = [(i, i) for i in range(num_statements)] - results = 
self.execute_concurrent_helper(self.session, list(zip(statements, parameters))) + results = \ + test_fn(self.session, list(zip(statements, parameters))) if zip_args else \ + test_fn(self.session, statement, parameters) self.assertEqual(num_statements, len(results)) for success, result in results: self.assertTrue(success) @@ -100,38 +104,43 @@ def test_execute_concurrent(self): statements = cycle((statement, )) parameters = [(i, ) for i in range(num_statements)] - results = self.execute_concurrent_helper(self.session, list(zip(statements, parameters))) + results = \ + test_fn(self.session, list(zip(statements, parameters))) if zip_args else \ + test_fn(self.session, statement, parameters) + validate_fn(num_statements, results) + + def execute_concurrent_validate_tuple(self, num_statements, results): self.assertEqual(num_statements, len(results)) self.assertEqual([(True, [(i,)]) for i in range(num_statements)], results) - def test_execute_concurrent_with_args(self): - for num_statements in (0, 1, 2, 7, 10, 99, 100, 101, 199, 200, 201): - statement = SimpleStatement( - "INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", - consistency_level=ConsistencyLevel.QUORUM) - parameters = [(i, i) for i in range(num_statements)] - - results = self.execute_concurrent_args_helper(self.session, statement, parameters) + def execute_concurrent_validate_dict(self, num_statements, results): self.assertEqual(num_statements, len(results)) - for success, result in results: - self.assertTrue(success) - self.assertFalse(result) + self.assertEqual([(True, [{"v":i}]) for i in range(num_statements)], results) - # read - statement = SimpleStatement( - "SELECT v FROM test3rf.test WHERE k=%s", - consistency_level=ConsistencyLevel.QUORUM) - parameters = [(i, ) for i in range(num_statements)] + def test_execute_concurrent(self): + self.execute_concurrent_base(self.execute_concurrent_helper, \ + self.execute_concurrent_validate_tuple) - results = self.execute_concurrent_args_helper(self.session, statement, parameters) - self.assertEqual(num_statements, len(results)) - self.assertEqual([(True, [(i,)]) for i in range(num_statements)], results) + def test_execute_concurrent_with_args(self): + self.execute_concurrent_base(self.execute_concurrent_args_helper, \ + self.execute_concurrent_validate_tuple, \ + zip_args=False) + + def test_execute_concurrent_with_execution_profile(self): + def run_fn(*args, **kwargs): + return self.execute_concurrent_helper(*args, execution_profile=EXEC_PROFILE_DICT, **kwargs) + self.execute_concurrent_base(run_fn, self.execute_concurrent_validate_dict) + + def test_execute_concurrent_with_args_and_execution_profile(self): + def run_fn(*args, **kwargs): + return self.execute_concurrent_args_helper(*args, execution_profile=EXEC_PROFILE_DICT, **kwargs) + self.execute_concurrent_base(run_fn, self.execute_concurrent_validate_dict, zip_args=False) def test_execute_concurrent_with_args_generator(self): """ Test to validate that generator based results are surfaced correctly - Repeatedly inserts data into a a table and attempts to query it. It then validates that the + Repeatedly inserts data into a table and attempts to query it. 
It then validates that the results are returned in the order expected @since 2.7.0 @@ -151,6 +160,12 @@ def test_execute_concurrent_with_args_generator(self): self.assertTrue(success) self.assertFalse(result) + results = self.execute_concurrent_args_helper(self.session, statement, parameters, results_generator=True) + for result in results: + self.assertTrue(isinstance(result, ExecutionResult)) + self.assertTrue(result.success) + self.assertFalse(result.result_or_exc) + # read statement = SimpleStatement( "SELECT v FROM test3rf.test WHERE k=%s", @@ -158,6 +173,7 @@ def test_execute_concurrent_with_args_generator(self): parameters = [(i, ) for i in range(num_statements)] results = self.execute_concurrent_args_helper(self.session, statement, parameters, results_generator=True) + for i in range(num_statements): result = next(results) self.assertEqual((True, [(i,)]), result) @@ -198,7 +214,7 @@ def test_execute_concurrent_paged_result_generator(self): """ Test to validate that generator based results are surfaced correctly when paging is used - Inserts data into a a table and attempts to query it. It then validates that the + Inserts data into a table and attempts to query it. It then validates that the results are returned as expected (no order specified) @since 2.7.0 diff --git a/tests/integration/standard/test_connection.py b/tests/integration/standard/test_connection.py index 51dd11a74b..e7177d8770 100644 --- a/tests/integration/standard/test_connection.py +++ b/tests/integration/standard/test_connection.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,35 +14,189 @@ # See the License for the specific language governing permissions and # limitations under the License. 
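The test_concurrent.py changes above exercise the execution_profile argument that execute_concurrent_with_args now forwards; a hedged usage sketch, with contact points, keyspace and table chosen for illustration only::

    from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT
    from cassandra.concurrent import execute_concurrent_with_args
    from cassandra.query import dict_factory, tuple_factory

    cluster = Cluster(execution_profiles={
        EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=tuple_factory),
        "dict": ExecutionProfile(row_factory=dict_factory),
    })
    session = cluster.connect()

    select = session.prepare("SELECT v FROM test3rf.test WHERE k=?")
    parameters = [(i,) for i in range(10)]

    # Every request in the batch runs under the named profile, so rows come
    # back as dicts instead of tuples.
    results = execute_concurrent_with_args(
        session, select, parameters, execution_profile="dict")
    for success, rows in results:
        assert success

    cluster.shutdown()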
-try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest from functools import partial -from six.moves import range +from unittest.mock import patch +import logging import sys +import threading from threading import Thread, Event import time +from unittest import SkipTest -from cassandra import ConsistencyLevel, OperationTimedOut -from cassandra.cluster import NoHostAvailable -from cassandra.io.asyncorereactor import AsyncoreConnection +from cassandra import ConsistencyLevel, OperationTimedOut, DependencyException +from cassandra.cluster import NoHostAvailable, ConnectionShutdown, ExecutionProfile, EXEC_PROFILE_DEFAULT from cassandra.protocol import QueryMessage +from cassandra.policies import HostFilterPolicy, RoundRobinPolicy, HostStateListener +from cassandra.pool import HostConnectionPool from tests import is_monkey_patched -from tests.integration import use_singledc, PROTOCOL_VERSION +from tests.integration import use_singledc, get_node, CASSANDRA_IP, local, \ + requiresmallclockgranularity, greaterthancass20, TestCluster + +try: + import cassandra.io.asyncorereactor + from cassandra.io.asyncorereactor import AsyncoreConnection +except DependencyException: + AsyncoreConnection = None try: from cassandra.io.libevreactor import LibevConnection -except ImportError: + import cassandra.io.libevreactor +except DependencyException: LibevConnection = None +log = logging.getLogger(__name__) + + def setup_module(): use_singledc() +class ConnectionTimeoutTest(unittest.TestCase): + + def setUp(self): + self.cluster = TestCluster(execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile( + load_balancing_policy=HostFilterPolicy( + RoundRobinPolicy(), predicate=lambda host: host.address == CASSANDRA_IP + ) + ) + }) + + self.session = self.cluster.connect() + + def tearDown(self): + self.cluster.shutdown() + + @patch('cassandra.connection.Connection.max_in_flight', 2) + def test_in_flight_timeout(self): + """ + Test to ensure that connection id fetching will block when max_id is reached/ + + In previous versions of the driver this test will cause a + NoHostAvailable exception to be thrown, when the max_id is restricted + + @since 3.3 + @jira_ticket PYTHON-514 + @expected_result When many requests are run on a single node connection acquisition should block + until connection is available or the request times out. + + @test_category connection timeout + """ + futures = [] + query = '''SELECT * FROM system.local''' + for _ in range(100): + futures.append(self.session.execute_async(query)) + + for future in futures: + future.result() + + +class TestHostListener(HostStateListener): + host_down = None + + def on_down(self, host): + self.host_down = True + + def on_up(self, host): + self.host_down = False + + +class HeartbeatTest(unittest.TestCase): + """ + Test to validate failing a heartbeat check doesn't mark a host as down + + @since 3.3 + @jira_ticket PYTHON-286 + @expected_result host should be marked down when heartbeat fails. 
This + happens after PYTHON-734 + + @test_category connection heartbeat + """ + + def setUp(self): + self.cluster = TestCluster(idle_heartbeat_interval=1) + self.session = self.cluster.connect(wait_for_all_pools=True) + + def tearDown(self): + self.cluster.shutdown() + + @local + @greaterthancass20 + def test_heart_beat_timeout(self): + # Setup a host listener to ensure the nodes don't go down + test_listener = TestHostListener() + host = "127.0.0.1:9042" + node = get_node(1) + initial_connections = self.fetch_connections(host, self.cluster) + self.assertNotEqual(len(initial_connections), 0) + self.cluster.register_listener(test_listener) + # Pause the node + try: + node.pause() + # Wait for connections associated with this host go away + self.wait_for_no_connections(host, self.cluster) + + # Wait to seconds for the driver to be notified + time.sleep(2) + self.assertTrue(test_listener.host_down) + # Resume paused node + finally: + node.resume() + # Run a query to ensure connections are re-established + current_host = "" + count = 0 + while current_host != host and count < 100: + rs = self.session.execute_async("SELECT * FROM system.local", trace=False) + rs.result() + current_host = str(rs._current_host) + count += 1 + time.sleep(.1) + self.assertLess(count, 100, "Never connected to the first node") + new_connections = self.wait_for_connections(host, self.cluster) + self.assertFalse(test_listener.host_down) + # Make sure underlying new connections don't match previous ones + for connection in initial_connections: + self.assertFalse(connection in new_connections) + + def fetch_connections(self, host, cluster): + # Given a cluster object and host grab all connection associated with that host + connections = [] + holders = cluster.get_connection_holders() + for conn in holders: + if host == str(getattr(conn, 'host', '')): + if isinstance(conn, HostConnectionPool): + if conn._connections is not None and len(conn._connections) > 0: + connections.append(conn._connections) + else: + if conn._connection is not None: + connections.append(conn._connection) + return connections + + def wait_for_connections(self, host, cluster): + retry = 0 + while(retry < 300): + retry += 1 + connections = self.fetch_connections(host, cluster) + if len(connections) is not 0: + return connections + time.sleep(.1) + self.fail("No new connections found") + + def wait_for_no_connections(self, host, cluster): + retry = 0 + while(retry < 100): + retry += 1 + connections = self.fetch_connections(host, cluster) + if len(connections) is 0: + return + time.sleep(.5) + self.fail("Connections never cleared") + + class ConnectionTests(object): klass = None @@ -60,9 +216,15 @@ def get_connection(self, timeout=5): e = None for i in range(5): try: - conn = self.klass.factory(host='127.0.0.1', timeout=timeout, protocol_version=PROTOCOL_VERSION) + contact_point = CASSANDRA_IP + conn = self.klass.factory( + endpoint=contact_point, + timeout=timeout, + protocol_version=TestCluster.DEFAULT_PROTOCOL_VERSION, + allow_beta_protocol_version=TestCluster.DEFAULT_ALLOW_BETA + ) break - except (OperationTimedOut, NoHostAvailable) as e: + except (OperationTimedOut, NoHostAvailable, ConnectionShutdown) as e: continue if conn: @@ -226,6 +388,7 @@ def send_msgs(conn, event): for t in threads: t.join() + @requiresmallclockgranularity def test_connect_timeout(self): # Underlying socket implementations don't always throw a socket timeout even with min float # This can be timing sensitive, added retry to ensure failure occurs if it can @@ -234,7 +397,8 
@@ def test_connect_timeout(self): for i in range(max_retry_count): start = time.time() try: - self.get_connection(timeout=sys.float_info.min) + conn = self.get_connection(timeout=sys.float_info.min) + conn.close() except Exception as e: end = time.time() self.assertAlmostEqual(start, end, 1) @@ -242,20 +406,58 @@ def test_connect_timeout(self): break self.assertTrue(exception_thrown) + def test_subclasses_share_loop(self): + + if self.klass not in (AsyncoreConnection, LibevConnection): + raise SkipTest + + class C1(self.klass): + pass + + class C2(self.klass): + pass + + clusterC1 = TestCluster(connection_class=C1) + clusterC1.connect(wait_for_all_pools=True) + + clusterC2 = TestCluster(connection_class=C2) + clusterC2.connect(wait_for_all_pools=True) + self.addCleanup(clusterC1.shutdown) + self.addCleanup(clusterC2.shutdown) + + self.assertEqual(len(get_eventloop_threads(self.event_loop_name)), 1) + + +def get_eventloop_threads(name): + all_threads = list(threading.enumerate()) + log.debug('all threads: {}'.format(all_threads)) + log.debug('all names: {}'.format([thread.name for thread in all_threads])) + event_loops_threads = [thread for thread in all_threads if name == thread.name] + + return event_loops_threads + class AsyncoreConnectionTests(ConnectionTests, unittest.TestCase): klass = AsyncoreConnection + event_loop_name = "asyncore_cassandra_driver_event_loop" def setUp(self): if is_monkey_patched(): raise unittest.SkipTest("Can't test asyncore with monkey patching") + if AsyncoreConnection is None: + raise unittest.SkipTest('Unable to import asyncore module') ConnectionTests.setUp(self) + def clean_global_loop(self): + cassandra.io.asyncorereactor._global_loop._cleanup() + cassandra.io.asyncorereactor._global_loop = None + class LibevConnectionTests(ConnectionTests, unittest.TestCase): klass = LibevConnection + event_loop_name = "event_loop" def setUp(self): if is_monkey_patched(): @@ -264,3 +466,7 @@ def setUp(self): raise unittest.SkipTest( 'libev does not appear to be installed properly') ConnectionTests.setUp(self) + + def clean_global_loop(self): + cassandra.io.libevreactor._global_loop._cleanup() + cassandra.io.libevreactor._global_loop = None diff --git a/tests/integration/standard/test_control_connection.py b/tests/integration/standard/test_control_connection.py index 07c5bac992..9d579476d2 100644 --- a/tests/integration/standard/test_control_connection.py +++ b/tests/integration/standard/test_control_connection.py @@ -1,4 +1,4 @@ -# Copyright 2013-2016 DataStax, Inc. +# Copyright DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
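test_subclasses_share_loop above relies on each reactor running its I/O on a single named thread; a driver-free sketch of the bookkeeping that get_eventloop_threads() performs (the thread name shown matches the asyncore reactor, other reactors use different names)::

    import threading

    def count_threads_named(name):
        # Reactor event loops run on a dedicated, named daemon thread, so the
        # count tells us how many independent loops are actually alive.
        return sum(1 for t in threading.enumerate() if t.name == name)

    print(count_threads_named("asyncore_cassandra_driver_event_loop"))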
@@ -14,16 +14,13 @@ # # # +from cassandra import InvalidRequest -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest -from cassandra.cluster import Cluster from cassandra.protocol import ConfigurationException -from tests.integration import use_singledc, PROTOCOL_VERSION +from tests.integration import use_singledc, PROTOCOL_VERSION, TestCluster, greaterthanorequalcass40, notdse from tests.integration.datatype_utils import update_datatypes @@ -38,13 +35,12 @@ def setUp(self): raise unittest.SkipTest( "Native protocol 3,0+ is required for UDTs using %r" % (PROTOCOL_VERSION,)) - self.cluster = Cluster(protocol_version=PROTOCOL_VERSION) - self.session = self.cluster.connect() + self.cluster = TestCluster() def tearDown(self): try: self.session.execute("DROP KEYSPACE keyspacetodrop ") - except (ConfigurationException): + except (ConfigurationException, InvalidRequest): # we already removed the keyspace. pass self.cluster.shutdown() @@ -65,6 +61,7 @@ def test_drop_keyspace(self): @test_category connection """ + self.session = self.cluster.connect() self.session.execute(""" CREATE KEYSPACE keyspacetodrop WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1' } @@ -76,3 +73,56 @@ def test_drop_keyspace(self): self.session.execute("DROP KEYSPACE keyspacetodrop") cc_id_post_drop = id(self.cluster.control_connection._connection) self.assertEqual(cc_id_post_drop, cc_id_pre_drop) + + def test_get_control_connection_host(self): + """ + Test to validate Cluster.get_control_connection_host() metadata + + @since 3.5.0 + @jira_ticket PYTHON-583 + @expected_result the control connection metadata should accurately reflect cluster state. + + @test_category metadata + """ + + host = self.cluster.get_control_connection_host() + self.assertEqual(host, None) + + self.session = self.cluster.connect() + cc_host = self.cluster.control_connection._connection.host + + host = self.cluster.get_control_connection_host() + self.assertEqual(host.address, cc_host) + self.assertEqual(host.is_up, True) + + # reconnect and make sure that the new host is reflected correctly + self.cluster.control_connection._reconnect() + new_host = self.cluster.get_control_connection_host() + self.assertNotEqual(host, new_host) + + @notdse + @greaterthanorequalcass40 + def test_control_connection_port_discovery(self): + """ + Test to validate that the correct port is discovered when peersV2 is used (C* 4.0+). + + Unit tests already validate that the port can be picked up (or not) from the query. This validates + it picks up the correct port from a real server and is able to connect. 
+ """ + self.cluster = TestCluster() + + host = self.cluster.get_control_connection_host() + self.assertEqual(host, None) + + self.session = self.cluster.connect() + cc_endpoint = self.cluster.control_connection._connection.endpoint + + host = self.cluster.get_control_connection_host() + self.assertEqual(host.endpoint, cc_endpoint) + self.assertEqual(host.is_up, True) + hosts = self.cluster.metadata.all_hosts() + self.assertEqual(3, len(hosts)) + + for host in hosts: + self.assertEqual(9042, host.broadcast_rpc_port) + self.assertEqual(7000, host.broadcast_port) diff --git a/tests/integration/standard/test_custom_cluster.py b/tests/integration/standard/test_custom_cluster.py new file mode 100644 index 0000000000..bb3f716984 --- /dev/null +++ b/tests/integration/standard/test_custom_cluster.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cassandra.cluster import NoHostAvailable +from tests.integration import use_singledc, get_cluster, remove_cluster, local, TestCluster +from tests.util import wait_until, wait_until_not_raised + +import unittest + + +def setup_module(): + use_singledc(start=False) + ccm_cluster = get_cluster() + ccm_cluster.stop() + config_options = {'native_transport_port': 9046} + ccm_cluster.set_configuration_options(config_options) + # can't use wait_for_binary_proto cause ccm tries on port 9042 + ccm_cluster.start(wait_for_binary_proto=False) + # wait until all nodes are up + wait_until_not_raised(lambda: TestCluster(contact_points=['127.0.0.1'], port=9046).connect().shutdown(), 1, 20) + wait_until_not_raised(lambda: TestCluster(contact_points=['127.0.0.2'], port=9046).connect().shutdown(), 1, 20) + wait_until_not_raised(lambda: TestCluster(contact_points=['127.0.0.3'], port=9046).connect().shutdown(), 1, 20) + + +def teardown_module(): + remove_cluster() + + +class CustomClusterTests(unittest.TestCase): + + @local + def test_connection_honor_cluster_port(self): + """ + Test that the initial contact point and discovered nodes honor + the cluster port on new connection. + + All hosts should be marked as up and we should be able to execute queries on it. 
+ """ + cluster = TestCluster() + with self.assertRaises(NoHostAvailable): + cluster.connect() # should fail on port 9042 + + cluster = TestCluster(port=9046) + session = cluster.connect(wait_for_all_pools=True) + + wait_until(lambda: len(cluster.metadata.all_hosts()) == 3, 1, 5) + for host in cluster.metadata.all_hosts(): + self.assertTrue(host.is_up) + session.execute("select * from system.local", host=host) diff --git a/tests/integration/standard/test_custom_payload.py b/tests/integration/standard/test_custom_payload.py index 3d4b849661..374bee9046 100644 --- a/tests/integration/standard/test_custom_payload.py +++ b/tests/integration/standard/test_custom_payload.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -13,21 +15,21 @@ # limitations under the License. -try: - import unittest2 as unittest -except ImportError: - import unittest - -import six +import unittest from cassandra.query import (SimpleStatement, BatchStatement, BatchType) -from cassandra.cluster import Cluster -from tests.integration import use_singledc, PROTOCOL_VERSION +from tests.integration import use_singledc, PROTOCOL_VERSION, local, TestCluster + def setup_module(): use_singledc() +#These test rely on the custom payload being returned but by default C* +#ignores all the payloads. 
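A hedged illustration of the custom-payload round trip the tests below exercise: payload keys are strings and values are bytes, and against stock Cassandra the response payload stays empty because the server discards it (contact points and query are illustrative)::

    from cassandra.cluster import Cluster
    from cassandra.query import SimpleStatement

    cluster = Cluster()
    session = cluster.connect()

    statement = SimpleStatement("SELECT release_version FROM system.local")
    payload = {"key1": b"value1", "key2": b"value2"}

    future = session.execute_async(statement, custom_payload=payload)
    future.result()
    # Populated only if the server (e.g. a custom QueryHandler) echoes a payload back.
    print(future.custom_payload)

    cluster.shutdown()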
+ + +@local class CustomPayloadTests(unittest.TestCase): def setUp(self): @@ -35,7 +37,7 @@ def setUp(self): raise unittest.SkipTest( "Native protocol 4,0+ is required for custom payloads, currently using %r" % (PROTOCOL_VERSION,)) - self.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + self.cluster = TestCluster() self.session = self.cluster.connect() def tearDown(self): @@ -138,16 +140,16 @@ def validate_various_custom_payloads(self, statement): # Long key value pair key_value = "x" * 10 - custom_payload = {key_value: six.b(key_value)} + custom_payload = {key_value: key_value.encode()} self.execute_async_validate_custom_payload(statement=statement, custom_payload=custom_payload) # Max supported value key pairs according C* binary protocol v4 should be 65534 (unsigned short max value) for i in range(65534): - custom_payload[str(i)] = six.b('x') + custom_payload[str(i)] = b'x' self.execute_async_validate_custom_payload(statement=statement, custom_payload=custom_payload) # Add one custom payload to this is too many key value pairs and should fail - custom_payload[str(65535)] = six.b('x') + custom_payload[str(65535)] = b'x' with self.assertRaises(ValueError): self.execute_async_validate_custom_payload(statement=statement, custom_payload=custom_payload) diff --git a/tests/integration/standard/test_custom_protocol_handler.py b/tests/integration/standard/test_custom_protocol_handler.py index 63a8380902..5c75684787 100644 --- a/tests/integration/standard/test_custom_protocol_handler.py +++ b/tests/integration/standard/test_custom_protocol_handler.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,20 +14,22 @@ # See the License for the specific language governing permissions and # limitations under the License. 
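The cap asserted in validate_various_custom_payloads above comes from the native protocol, which encodes a custom payload as a bytes map whose entry count is a 16-bit unsigned short; a small sketch of that limit::

    MAX_PAYLOAD_ENTRIES = 2 ** 16 - 1  # 65535, the largest unsigned short

    payload = {str(i): b"x" for i in range(MAX_PAYLOAD_ENTRIES)}
    assert len(payload) == MAX_PAYLOAD_ENTRIES

    # One more entry pushes the map past what the protocol can encode, which is
    # why the driver raises ValueError in the test above.
    payload[str(MAX_PAYLOAD_ENTRIES)] = b"x"
    assert len(payload) == MAX_PAYLOAD_ENTRIES + 1  # too many to serialize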
-try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest + +from cassandra.protocol import ProtocolHandler, ResultMessage, QueryMessage, UUIDType, read_int +from cassandra.query import tuple_factory, SimpleStatement +from cassandra.cluster import (ResponseFuture, ExecutionProfile, EXEC_PROFILE_DEFAULT, + ContinuousPagingOptions, NoHostAvailable) +from cassandra import ProtocolVersion, ConsistencyLevel -from cassandra.protocol import ProtocolHandler, ResultMessage, UUIDType, read_int, EventMessage -from cassandra.query import tuple_factory -from cassandra.cluster import Cluster -from tests.integration import use_singledc, PROTOCOL_VERSION, drop_keyspace_shutdown_cluster +from tests.integration import use_singledc, drop_keyspace_shutdown_cluster, \ + greaterthanorequalcass30, execute_with_long_wait_retry, greaterthanorequaldse51, greaterthanorequalcass3_10, \ + TestCluster, greaterthanorequalcass40, requirecassandra from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES from tests.integration.standard.utils import create_table_with_all_types, get_all_primitive_params -from six import binary_type import uuid +from unittest import mock def setup_module(): @@ -37,7 +41,7 @@ class CustomProtocolHandlerTest(unittest.TestCase): @classmethod def setUpClass(cls): - cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cls.cluster = TestCluster() cls.session = cls.cluster.connect() cls.session.execute("CREATE KEYSPACE custserdes WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}") cls.session.set_keyspace("custserdes") @@ -62,19 +66,20 @@ def test_custom_raw_uuid_row_results(self): """ # Ensure that we get normal uuid back first - session = Cluster(protocol_version=PROTOCOL_VERSION).connect(keyspace="custserdes") - session.row_factory = tuple_factory + cluster = TestCluster( + execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=tuple_factory)} + ) + session = cluster.connect(keyspace="custserdes") + result = session.execute("SELECT schema_version FROM system.local") uuid_type = result[0][0] self.assertEqual(type(uuid_type), uuid.UUID) - # use our custom protocol handlder - + # use our custom protocol handler session.client_protocol_handler = CustomTestRawRowType - session.row_factory = tuple_factory result_set = session.execute("SELECT schema_version FROM system.local") raw_value = result_set[0][0] - self.assertTrue(isinstance(raw_value, binary_type)) + self.assertTrue(isinstance(raw_value, bytes)) self.assertEqual(len(raw_value), 16) # Ensure that we get normal uuid back when we re-connect @@ -82,7 +87,7 @@ def test_custom_raw_uuid_row_results(self): result_set = session.execute("SELECT schema_version FROM system.local") uuid_type = result_set[0][0] self.assertEqual(type(uuid_type), uuid.UUID) - session.shutdown() + cluster.shutdown() def test_custom_raw_row_results_all_types(self): """ @@ -99,9 +104,11 @@ def test_custom_raw_row_results_all_types(self): @test_category data_types:serialization """ # Connect using a custom protocol handler that tracks the various types the result message is used with. 
- session = Cluster(protocol_version=PROTOCOL_VERSION).connect(keyspace="custserdes") + cluster = TestCluster( + execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=tuple_factory)} + ) + session = cluster.connect(keyspace="custserdes") session.client_protocol_handler = CustomProtocolHandlerResultMessageTracked - session.row_factory = tuple_factory colnames = create_table_with_all_types("alltypes", session, 1) columns_string = ", ".join(colnames) @@ -113,7 +120,137 @@ def test_custom_raw_row_results_all_types(self): self.assertEqual(actual, expected) # Ensure we have covered the various primitive types self.assertEqual(len(CustomResultMessageTracked.checked_rev_row_set), len(PRIMITIVE_DATATYPES)-1) - session.shutdown() + cluster.shutdown() + + @requirecassandra + @greaterthanorequalcass40 + def test_protocol_divergence_v5_fail_by_continuous_paging(self): + """ + Test to validate that V5 and DSE_V1 diverge. ContinuousPagingOptions is not supported by V5 + + @since DSE 2.0b3 GRAPH 1.0b1 + @jira_ticket PYTHON-694 + @expected_result NoHostAvailable will be risen when the continuous_paging_options parameter is set + + @test_category connection + """ + cluster = TestCluster(protocol_version=ProtocolVersion.V5, allow_beta_protocol_version=True) + session = cluster.connect() + + max_pages = 4 + max_pages_per_second = 3 + continuous_paging_options = ContinuousPagingOptions(max_pages=max_pages, + max_pages_per_second=max_pages_per_second) + + future = self._send_query_message(session, timeout=session.default_timeout, + consistency_level=ConsistencyLevel.ONE, + continuous_paging_options=continuous_paging_options) + + # This should raise NoHostAvailable because continuous paging is not supported under ProtocolVersion.DSE_V1 + with self.assertRaises(NoHostAvailable) as context: + future.result() + self.assertIn("Continuous paging may only be used with protocol version ProtocolVersion.DSE_V1 or higher", + str(context.exception)) + + cluster.shutdown() + + @greaterthanorequalcass30 + def test_protocol_divergence_v4_fail_by_flag_uses_int(self): + """ + Test to validate that the _PAGE_SIZE_FLAG is not treated correctly in V4 if the flags are + written using write_uint instead of write_int + + @since 3.9 + @jira_ticket PYTHON-713 + @expected_result the fetch_size=1 parameter will be ignored + + @test_category connection + """ + self._protocol_divergence_fail_by_flag_uses_int(ProtocolVersion.V4, uses_int_query_flag=False, + int_flag=True) + + @requirecassandra + @greaterthanorequalcass40 + def test_protocol_v5_uses_flag_int(self): + """ + Test to validate that the _PAGE_SIZE_FLAG is treated correctly using write_uint for V5 + + @jira_ticket PYTHON-694 + @expected_result the fetch_size=1 parameter will be honored + + @test_category connection + """ + self._protocol_divergence_fail_by_flag_uses_int(ProtocolVersion.V5, uses_int_query_flag=True, beta=True, + int_flag=True) + + @greaterthanorequaldse51 + def test_protocol_dsev1_uses_flag_int(self): + """ + Test to validate that the _PAGE_SIZE_FLAG is treated correctly using write_uint for DSE_V1 + + @jira_ticket PYTHON-694 + @expected_result the fetch_size=1 parameter will be honored + + @test_category connection + """ + self._protocol_divergence_fail_by_flag_uses_int(ProtocolVersion.DSE_V1, uses_int_query_flag=True, + int_flag=True) + + @requirecassandra + @greaterthanorequalcass40 + def test_protocol_divergence_v5_fail_by_flag_uses_int(self): + """ + Test to validate that the _PAGE_SIZE_FLAG is treated correctly using write_uint for V5 + + 
@jira_ticket PYTHON-694 + @expected_result the fetch_size=1 parameter will be honored + + @test_category connection + """ + self._protocol_divergence_fail_by_flag_uses_int(ProtocolVersion.V5, uses_int_query_flag=False, beta=True, + int_flag=False) + + @greaterthanorequaldse51 + def test_protocol_divergence_dsev1_fail_by_flag_uses_int(self): + """ + Test to validate that the _PAGE_SIZE_FLAG is treated correctly using write_uint for DSE_V1 + + @jira_ticket PYTHON-694 + @expected_result the fetch_size=1 parameter will be honored + + @test_category connection + """ + self._protocol_divergence_fail_by_flag_uses_int(ProtocolVersion.DSE_V1, uses_int_query_flag=False, + int_flag=False) + + def _send_query_message(self, session, timeout, **kwargs): + query = "SELECT * FROM test3rf.test" + message = QueryMessage(query=query, **kwargs) + future = ResponseFuture(session, message, query=None, timeout=timeout) + future.send_request() + return future + + def _protocol_divergence_fail_by_flag_uses_int(self, version, uses_int_query_flag, int_flag = True, beta=False): + cluster = TestCluster(protocol_version=version, allow_beta_protocol_version=beta) + session = cluster.connect() + + query_one = SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (1, 1)") + query_two = SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (2, 2)") + + execute_with_long_wait_retry(session, query_one) + execute_with_long_wait_retry(session, query_two) + + with mock.patch('cassandra.protocol.ProtocolVersion.uses_int_query_flags', new=mock.Mock(return_value=int_flag)): + future = self._send_query_message(session, 10, + consistency_level=ConsistencyLevel.ONE, fetch_size=1) + + response = future.result() + + # This means the flag are not handled as they are meant by the server if uses_int=False + self.assertEqual(response.has_more_pages, uses_int_query_flag) + + execute_with_long_wait_retry(session, SimpleStatement("TRUNCATE test3rf.test")) + cluster.shutdown() class CustomResultMessageRaw(ResultMessage): @@ -125,18 +262,18 @@ class CustomResultMessageRaw(ResultMessage): my_type_codes[0xc] = UUIDType type_codes = my_type_codes - @classmethod - def recv_results_rows(cls, f, protocol_version, user_type_map): - paging_state, column_metadata = cls.recv_results_metadata(f, user_type_map) + def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata, column_encryption_policy): + self.recv_results_metadata(f, user_type_map) + column_metadata = self.column_metadata or result_metadata rowcount = read_int(f) - rows = [cls.recv_row(f, len(column_metadata)) for _ in range(rowcount)] - coltypes = [c[3] for c in column_metadata] - return (paging_state, (coltypes, rows)) + self.parsed_rows = [self.recv_row(f, len(column_metadata)) for _ in range(rowcount)] + self.column_names = [c[2] for c in column_metadata] + self.column_types = [c[3] for c in column_metadata] class CustomTestRawRowType(ProtocolHandler): """ - This is the a custom protocol handler that will substitute the the + This is a custom protocol handler that will substitute the customResultMesageRowRaw Result message for our own implementation """ my_opcodes = ProtocolHandler.message_types_by_opcode.copy() @@ -146,7 +283,7 @@ class CustomTestRawRowType(ProtocolHandler): class CustomResultMessageTracked(ResultMessage): """ - This is a custom Result Message that is use to track what primitive types + This is a custom Result Message that is used to track what primitive types have been processed when it receives results """ my_type_codes = 
ResultMessage.type_codes.copy() @@ -154,28 +291,25 @@ class CustomResultMessageTracked(ResultMessage): type_codes = my_type_codes checked_rev_row_set = set() - @classmethod - def recv_results_rows(cls, f, protocol_version, user_type_map): - paging_state, column_metadata = cls.recv_results_metadata(f, user_type_map) + def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata, column_encryption_policy): + self.recv_results_metadata(f, user_type_map) + column_metadata = self.column_metadata or result_metadata rowcount = read_int(f) - rows = [cls.recv_row(f, len(column_metadata)) for _ in range(rowcount)] - colnames = [c[2] for c in column_metadata] - coltypes = [c[3] for c in column_metadata] - cls.checked_rev_row_set.update(coltypes) - parsed_rows = [ + rows = [self.recv_row(f, len(column_metadata)) for _ in range(rowcount)] + self.column_names = [c[2] for c in column_metadata] + self.column_types = [c[3] for c in column_metadata] + self.checked_rev_row_set.update(self.column_types) + self.parsed_rows = [ tuple(ctype.from_binary(val, protocol_version) - for ctype, val in zip(coltypes, row)) + for ctype, val in zip(self.column_types, row)) for row in rows] - return (paging_state, (colnames, parsed_rows)) class CustomProtocolHandlerResultMessageTracked(ProtocolHandler): """ - This is the a custom protocol handler that will substitute the the + This is a custom protocol handler that will substitute the CustomTestRawRowTypeTracked Result message for our own implementation """ my_opcodes = ProtocolHandler.message_types_by_opcode.copy() my_opcodes[CustomResultMessageTracked.opcode] = CustomResultMessageTracked message_types_by_opcode = my_opcodes - - diff --git a/tests/integration/standard/test_cython_protocol_handlers.py b/tests/integration/standard/test_cython_protocol_handlers.py index dc24d0a3e6..83d39407c4 100644 --- a/tests/integration/standard/test_cython_protocol_handlers.py +++ b/tests/integration/standard/test_cython_protocol_handlers.py @@ -2,20 +2,21 @@ # Based on test_custom_protocol_handler.py -try: - import unittest2 as unittest -except ImportError: - import unittest +import unittest -from cassandra.query import tuple_factory -from cassandra.cluster import Cluster -from cassandra.protocol import ProtocolHandler, LazyProtocolHandler, NumpyProtocolHandler +from itertools import count -from tests.integration import use_singledc, PROTOCOL_VERSION, notprotocolv1, drop_keyspace_shutdown_cluster +from cassandra.cluster import ExecutionProfile, EXEC_PROFILE_DEFAULT +from cassandra.concurrent import execute_concurrent_with_args +from cassandra.cython_deps import HAVE_CYTHON, HAVE_NUMPY +from cassandra.protocol import ProtocolHandler, LazyProtocolHandler, NumpyProtocolHandler +from cassandra.query import tuple_factory +from tests import VERIFY_CYTHON +from tests.integration import use_singledc, notprotocolv1, \ + drop_keyspace_shutdown_cluster, BasicSharedKeyspaceUnitTestCase, greaterthancass21, TestCluster from tests.integration.datatype_utils import update_datatypes from tests.integration.standard.utils import ( create_table_with_all_types, get_all_primitive_params, get_primitive_datatypes) - from tests.unit.cython.utils import cythontest, numpytest @@ -30,7 +31,7 @@ class CythonProtocolHandlerTest(unittest.TestCase): @classmethod def setUpClass(cls): - cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cls.cluster = TestCluster() cls.session = cls.cluster.connect() cls.session.execute("CREATE KEYSPACE testspace WITH replication = " "{ 'class' : 
'SimpleStrategy', 'replication_factor': '1'}") @@ -39,7 +40,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - drop_keyspace_shutdown_cluster("testspace", cls.session, cls.session) + drop_keyspace_shutdown_cluster("testspace", cls.session, cls.cluster) @cythontest def test_cython_parser(self): @@ -55,16 +56,16 @@ def test_cython_lazy_parser(self): """ verify_iterator_data(self.assertEqual, get_data(LazyProtocolHandler)) - @notprotocolv1 @numpytest def test_cython_lazy_results_paged(self): """ Test Cython-based parser that returns an iterator, over multiple pages """ # arrays = { 'a': arr1, 'b': arr2, ... } - cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cluster = TestCluster( + execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=tuple_factory)} + ) session = cluster.connect(keyspace="testspace") - session.row_factory = tuple_factory session.client_protocol_handler = LazyProtocolHandler session.default_fetch_size = 2 @@ -95,9 +96,10 @@ def test_numpy_results_paged(self): Test Numpy-based parser that returns a NumPy array """ # arrays = { 'a': arr1, 'b': arr2, ... } - cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cluster = TestCluster( + execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=tuple_factory)} + ) session = cluster.connect(keyspace="testspace") - session.row_factory = tuple_factory session.client_protocol_handler = NumpyProtocolHandler session.default_fetch_size = 2 @@ -123,6 +125,20 @@ def test_numpy_results_paged(self): cluster.shutdown() + @numpytest + def test_cython_numpy_are_installed_valid(self): + """ + Test to validate that cython and numpy are installed correctly + @since 3.3.0 + @jira_ticket PYTHON-543 + @expected_result Cython and Numpy should be present + + @test_category configuration + """ + if VERIFY_CYTHON: + self.assertTrue(HAVE_CYTHON) + self.assertTrue(HAVE_NUMPY) + def _verify_numpy_page(self, page): colnames = self.colnames datatypes = get_primitive_datatypes() @@ -163,15 +179,16 @@ def get_data(protocol_handler): """ Get data from the test table. """ - cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cluster = TestCluster( + execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=tuple_factory)} + ) session = cluster.connect(keyspace="testspace") # use our custom protocol handler session.client_protocol_handler = protocol_handler - session.row_factory = tuple_factory results = session.execute("SELECT * FROM test_table") - session.shutdown() + cluster.shutdown() return results @@ -188,3 +205,56 @@ def verify_iterator_data(assertEqual, results): for expected, actual in zip(params, result): assertEqual(actual, expected) return count + + +class NumpyNullTest(BasicSharedKeyspaceUnitTestCase): + + @classmethod + def setUpClass(cls): + cls.common_setup(1, execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=tuple_factory)}) + + @numpytest + @greaterthancass21 + def test_null_types(self): + """ + Test to validate that the numpy protocol handler can deal with null values. + @since 3.3.0 + - updated 3.6.0: now numeric types used masked array + @jira_ticket PYTHON-550 + @expected_result Numpy can handle non mapped types' null values. 
+ + @test_category data_types:serialization + """ + s = self.session + s.client_protocol_handler = NumpyProtocolHandler + + table = "%s.%s" % (self.keyspace_name, self.function_table_name) + create_table_with_all_types(table, s, 10) + + begin_unset = max(s.execute('select primkey from %s' % (table,))[0]['primkey']) + 1 + keys_null = range(begin_unset, begin_unset + 10) + + # scatter some empty rows in here + insert = "insert into %s (primkey) values (%%s)" % (table,) + execute_concurrent_with_args(s, insert, ((k,) for k in keys_null)) + + result = s.execute("select * from %s" % (table,))[0] + + from numpy.ma import masked, MaskedArray + result_keys = result.pop('primkey') + mapped_index = [v[1] for v in sorted(zip(result_keys, count()))] + + had_masked = had_none = False + for col_array in result.values(): + # these have to be different branches (as opposed to comparing against an 'unset value') + # because None and `masked` have different identity and equals semantics + if isinstance(col_array, MaskedArray): + had_masked = True + [self.assertIsNot(col_array[i], masked) for i in mapped_index[:begin_unset]] + [self.assertIs(col_array[i], masked) for i in mapped_index[begin_unset:]] + else: + had_none = True + [self.assertIsNotNone(col_array[i]) for i in mapped_index[:begin_unset]] + [self.assertIsNone(col_array[i]) for i in mapped_index[begin_unset:]] + self.assertTrue(had_masked) + self.assertTrue(had_none) diff --git a/tests/integration/standard/test_dse.py b/tests/integration/standard/test_dse.py new file mode 100644 index 0000000000..0a339b6b3d --- /dev/null +++ b/tests/integration/standard/test_dse.py @@ -0,0 +1,94 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
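NumpyNullTest above depends on how numpy represents NULL cells: numeric columns containing NULLs come back from the NumpyProtocolHandler as masked arrays, and a NULL cell compares by identity to numpy.ma.masked rather than to None. A pure-numpy sketch of that behaviour, no driver involved::

    import numpy as np
    from numpy.ma import MaskedArray, masked

    # mask=True marks the cell that would have been NULL in the result set
    col = np.ma.array([1.0, 2.0, 3.0], mask=[False, True, False])

    assert isinstance(col, MaskedArray)
    assert col[1] is masked        # the NULL cell
    assert col[0] is not masked    # a populated cell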
+ +import os + +from packaging.version import Version + +from tests import notwindows +from tests.unit.cython.utils import notcython +from tests.integration import (execute_until_pass, + execute_with_long_wait_retry, use_cluster, TestCluster) + +import unittest + + +CCM_IS_DSE = (os.environ.get('CCM_IS_DSE', None) == 'true') + + +@unittest.skipIf(os.environ.get('CCM_ARGS', None), 'environment has custom CCM_ARGS; skipping') +@notwindows +@notcython # no need to double up on this test; also __default__ setting doesn't work +class DseCCMClusterTest(unittest.TestCase): + """ + This class can be executed setting the DSE_VERSION variable, for example: + DSE_VERSION=5.1.4 python2.7 -m nose tests/integration/standard/test_dse.py + If CASSANDRA_VERSION is set instead, it will be converted to the corresponding DSE_VERSION + """ + + def test_dse_5x(self): + self._test_basic(Version('5.1.10')) + + def test_dse_60(self): + self._test_basic(Version('6.0.2')) + + @unittest.skipUnless(CCM_IS_DSE, 'DSE version unavailable') + def test_dse_67(self): + self._test_basic(Version('6.7.0')) + + def _test_basic(self, dse_version): + """ + Test basic connection and usage + """ + cluster_name = '{}-{}'.format( + self.__class__.__name__, dse_version.base_version.replace('.', '_') + ) + use_cluster(cluster_name=cluster_name, nodes=[3], dse_options={}) + + cluster = TestCluster() + session = cluster.connect() + result = execute_until_pass( + session, + """ + CREATE KEYSPACE clustertests + WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} + """) + self.assertFalse(result) + + result = execute_with_long_wait_retry( + session, + """ + CREATE TABLE clustertests.cf0 ( + a text, + b text, + c text, + PRIMARY KEY (a, b) + ) + """) + self.assertFalse(result) + + result = session.execute( + """ + INSERT INTO clustertests.cf0 (a, b, c) VALUES ('a', 'b', 'c') + """) + self.assertFalse(result) + + result = session.execute("SELECT * FROM clustertests.cf0") + self.assertEqual([('a', 'b', 'c')], result) + + execute_with_long_wait_retry(session, "DROP KEYSPACE clustertests") + + cluster.shutdown() diff --git a/tests/integration/standard/test_metadata.py b/tests/integration/standard/test_metadata.py index 583943bcc3..8f7ba04883 100644 --- a/tests/integration/standard/test_metadata.py +++ b/tests/integration/standard/test_metadata.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,41 +14,151 @@ # See the License for the specific language governing permissions and # limitations under the License. 
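The DSE smoke test above, like most of these integration tests, gates behaviour on server versions with packaging.version.Version, which compares release numbers numerically and strips pre-release tags via base_version; a short sketch with illustrative values::

    from packaging.version import Version

    dse_version = Version("6.0.2")

    assert dse_version >= Version("5.1.10")        # numeric, not lexicographic
    assert Version("2.1.10") > Version("2.1.3")    # "2.1.10" would sort wrong as a string
    assert Version("4.0") > Version("4-a")         # pre-releases order before the final

    # _test_basic derives its ccm cluster name the same way:
    cluster_name = "DseCCMClusterTest-{}".format(dse_version.base_version.replace(".", "_"))
    print(cluster_name)  # DseCCMClusterTest-6_0_2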
-try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest +from collections import defaultdict import difflib -import six +import logging import sys import time -from mock import Mock, patch +import os +from packaging.version import Version +from unittest.mock import Mock, patch from cassandra import AlreadyExists, SignatureDescriptor, UserFunctionDescriptor, UserAggregateDescriptor -from cassandra.cluster import Cluster from cassandra.encoder import Encoder -from cassandra.metadata import (Metadata, KeyspaceMetadata, IndexMetadata, - Token, MD5Token, TokenMap, murmur3, Function, Aggregate, protect_name, protect_names, - get_schema_parser) -from cassandra.policies import SimpleConvictionPolicy -from cassandra.pool import Host +from cassandra.metadata import (IndexMetadata, Token, murmur3, Function, Aggregate, protect_name, protect_names, + RegisteredTableExtension, _RegisteredExtensionType, get_schema_parser, + group_keys_by_replica, NO_VALID_REPLICA) +from cassandra.util import SortedSet + +from tests.integration import (get_cluster, use_singledc, PROTOCOL_VERSION, execute_until_pass, + BasicSegregatedKeyspaceUnitTestCase, BasicSharedKeyspaceUnitTestCase, + BasicExistingKeyspaceUnitTestCase, drop_keyspace_shutdown_cluster, CASSANDRA_VERSION, + greaterthanorequaldse51, greaterthanorequalcass30, lessthancass30, local, + get_supported_protocol_versions, greaterthancass20, + greaterthancass21, assert_startswith, greaterthanorequalcass40, + greaterthanorequaldse67, lessthancass40, + TestCluster, DSE_VERSION, HCD_VERSION) + -from tests.integration import get_cluster, use_singledc, PROTOCOL_VERSION, get_server_versions, execute_until_pass, \ - BasicSegregatedKeyspaceUnitTestCase, BasicSharedKeyspaceUnitTestCase, drop_keyspace_shutdown_cluster +log = logging.getLogger(__name__) -from tests.unit.cython.utils import notcython def setup_module(): use_singledc() - global CASS_SERVER_VERSION - CASS_SERVER_VERSION = get_server_versions()[0] + + +class HostMetaDataTests(BasicExistingKeyspaceUnitTestCase): + @local + def test_host_addresses(self): + """ + Check to ensure that the broadcast_address, broadcast_rpc_address, + listen adresss, ports and host are is populated correctly. + + @since 3.3 + @jira_ticket PYTHON-332 + @expected_result They are populated for C*> 2.1.6, 2.2.0 + + @test_category metadata + """ + # All nodes should have the broadcast_address, rpc_address and host_id set + for host in self.cluster.metadata.all_hosts(): + self.assertIsNotNone(host.broadcast_address) + self.assertIsNotNone(host.broadcast_rpc_address) + self.assertIsNotNone(host.host_id) + + if not DSE_VERSION and CASSANDRA_VERSION >= Version('4-a'): + self.assertIsNotNone(host.broadcast_port) + self.assertIsNotNone(host.broadcast_rpc_port) + + con = self.cluster.control_connection.get_connections()[0] + local_host = con.host + + # The control connection node should have the listen address set. + listen_addrs = [host.listen_address for host in self.cluster.metadata.all_hosts()] + self.assertTrue(local_host in listen_addrs) + + # The control connection node should have the broadcast_rpc_address set. 
+ rpc_addrs = [host.broadcast_rpc_address for host in self.cluster.metadata.all_hosts()] + self.assertTrue(local_host in rpc_addrs) + + @unittest.skipUnless( + os.getenv('MAPPED_CASSANDRA_VERSION', None) is not None, + "Don't check the host version for test-dse") + def test_host_release_version(self): + """ + Checks the hosts release version and validates that it is equal to the + Cassandra version we are using in our test harness. + + @since 3.3 + @jira_ticket PYTHON-301 + @expected_result host.release version should match our specified Cassandra version. + + @test_category metadata + """ + for host in self.cluster.metadata.all_hosts(): + assert_startswith(host.release_version, CASSANDRA_VERSION.base_version) + + + +@local +class MetaDataRemovalTest(unittest.TestCase): + + def setUp(self): + self.cluster = TestCluster(contact_points=['127.0.0.1', '127.0.0.2', '127.0.0.3', '126.0.0.186']) + self.cluster.connect() + + def tearDown(self): + self.cluster.shutdown() + + def test_bad_contact_point(self): + """ + Checks to ensure that hosts that are not resolvable are excluded from the contact point list. + + @since 3.6 + @jira_ticket PYTHON-549 + @expected_result Invalid hosts on the contact list should be excluded + + @test_category metadata + """ + self.assertEqual(len(self.cluster.metadata.all_hosts()), 3) class SchemaMetadataTests(BasicSegregatedKeyspaceUnitTestCase): - def make_create_statement(self, partition_cols, clustering_cols=None, other_cols=None, compact=False): + def test_schema_metadata_disable(self): + """ + Checks to ensure that schema metadata_enabled, and token_metadata_enabled + flags work correctly. + + @since 3.3 + @jira_ticket PYTHON-327 + @expected_result schema metadata will not be populated when schema_metadata_enabled is fause + token_metadata will be missing when token_metadata is set to false + + @test_category metadata + """ + # Validate metadata is missing where appropriate + no_schema = TestCluster(schema_metadata_enabled=False) + no_schema_session = no_schema.connect() + self.assertEqual(len(no_schema.metadata.keyspaces), 0) + self.assertEqual(no_schema.metadata.export_schema_as_string(), '') + no_token = TestCluster(token_metadata_enabled=False) + no_token_session = no_token.connect() + self.assertEqual(len(no_token.metadata.token_map.token_to_host_owner), 0) + + # Do a simple query to ensure queries are working + query = "SELECT * FROM system.local" + no_schema_rs = no_schema_session.execute(query) + no_token_rs = no_token_session.execute(query) + self.assertIsNotNone(no_schema_rs[0]) + self.assertIsNotNone(no_token_rs[0]) + no_schema.shutdown() + no_token.shutdown() + + def make_create_statement(self, partition_cols, clustering_cols=None, other_cols=None): clustering_cols = clustering_cols or [] other_cols = other_cols or [] @@ -74,8 +186,6 @@ def make_create_statement(self, partition_cols, clustering_cols=None, other_cols statement += ")" statement += ")" - if compact: - statement += " WITH COMPACT STORAGE" return statement @@ -120,7 +230,13 @@ def test_basic_table_meta_properties(self): self.assertEqual([], tablemeta.clustering_key) self.assertEqual([u'a', u'b', u'c'], sorted(tablemeta.columns.keys())) - parser = get_schema_parser(self.cluster.control_connection._connection, 1) + cc = self.cluster.control_connection._connection + parser = get_schema_parser( + cc, + self.cluster.metadata.get_host(cc.host).release_version, + self.cluster.metadata.get_host(cc.host).dse_version, + 1 + ) for option in tablemeta.options: self.assertIn(option, 
parser.recognized_table_options) @@ -189,8 +305,8 @@ def test_composite_in_compound_primary_key(self): self.check_create_statement(tablemeta, create_statement) def test_compound_primary_keys_compact(self): - create_statement = self.make_create_statement(["a"], ["b"], ["c"], compact=True) - create_statement += " AND CLUSTERING ORDER BY (b ASC)" + create_statement = self.make_create_statement(["a"], ["b"], ["c"]) + create_statement += " WITH CLUSTERING ORDER BY (b ASC)" self.session.execute(create_statement) tablemeta = self.get_table_metadata() @@ -213,8 +329,8 @@ def test_cluster_column_ordering_reversed_metadata(self): @test_category metadata """ - create_statement = self.make_create_statement(["a"], ["b", "c"], ["d"], compact=True) - create_statement += " AND CLUSTERING ORDER BY (b ASC, c DESC)" + create_statement = self.make_create_statement(["a"], ["b", "c"], ["d"]) + create_statement += " WITH CLUSTERING ORDER BY (b ASC, c DESC)" self.session.execute(create_statement) tablemeta = self.get_table_metadata() b_column = tablemeta.columns['b'] @@ -223,8 +339,8 @@ def test_cluster_column_ordering_reversed_metadata(self): self.assertTrue(c_column.is_reversed) def test_compound_primary_keys_more_columns_compact(self): - create_statement = self.make_create_statement(["a"], ["b", "c"], ["d"], compact=True) - create_statement += " AND CLUSTERING ORDER BY (b ASC, c ASC)" + create_statement = self.make_create_statement(["a"], ["b", "c"], ["d"]) + create_statement += " WITH CLUSTERING ORDER BY (b ASC, c ASC)" self.session.execute(create_statement) tablemeta = self.get_table_metadata() @@ -235,7 +351,7 @@ def test_compound_primary_keys_more_columns_compact(self): self.check_create_statement(tablemeta, create_statement) def test_composite_primary_key_compact(self): - create_statement = self.make_create_statement(["a", "b"], [], ["c"], compact=True) + create_statement = self.make_create_statement(["a", "b"], [], ["c"]) self.session.execute(create_statement) tablemeta = self.get_table_metadata() @@ -246,8 +362,8 @@ def test_composite_primary_key_compact(self): self.check_create_statement(tablemeta, create_statement) def test_composite_in_compound_primary_key_compact(self): - create_statement = self.make_create_statement(["a", "b"], ["c"], ["d"], compact=True) - create_statement += " AND CLUSTERING ORDER BY (c ASC)" + create_statement = self.make_create_statement(["a", "b"], ["c"], ["d"]) + create_statement += " WITH CLUSTERING ORDER BY (c ASC)" self.session.execute(create_statement) tablemeta = self.get_table_metadata() @@ -257,13 +373,12 @@ def test_composite_in_compound_primary_key_compact(self): self.check_create_statement(tablemeta, create_statement) + @lessthancass30 def test_cql_compatibility(self): - if CASS_SERVER_VERSION >= (3, 0): - raise unittest.SkipTest("cql compatibility does not apply Cassandra 3.0+") # having more than one non-PK column is okay if there aren't any # clustering columns - create_statement = self.make_create_statement(["a"], [], ["b", "c", "d"], compact=True) + create_statement = self.make_create_statement(["a"], [], ["b", "c", "d"]) self.session.execute(create_statement) tablemeta = self.get_table_metadata() @@ -273,12 +388,12 @@ def test_cql_compatibility(self): self.assertTrue(tablemeta.is_cql_compatible) - # ... but if there are clustering columns, it's not CQL compatible. - # This is a hacky way to simulate having clustering columns. 
+ # It will be cql compatible after CASSANDRA-10857 + # since compact storage is being dropped tablemeta.clustering_key = ["foo", "bar"] tablemeta.columns["foo"] = None tablemeta.columns["bar"] = None - self.assertFalse(tablemeta.is_cql_compatible) + self.assertTrue(tablemeta.is_cql_compatible) def test_compound_primary_keys_ordering(self): create_statement = self.make_create_statement(["a"], ["b"], ["c"]) @@ -301,6 +416,57 @@ def test_composite_in_compound_primary_key_ordering(self): tablemeta = self.get_table_metadata() self.check_create_statement(tablemeta, create_statement) + @lessthancass40 + def test_compact_storage(self): + create_statement = self.make_create_statement(["a"], [], ["b"]) + create_statement += " WITH COMPACT STORAGE" + + self.session.execute(create_statement) + tablemeta = self.get_table_metadata() + self.check_create_statement(tablemeta, create_statement) + + @lessthancass40 + def test_dense_compact_storage(self): + create_statement = self.make_create_statement(["a"], ["b"], ["c"]) + create_statement += " WITH COMPACT STORAGE" + + self.session.execute(create_statement) + tablemeta = self.get_table_metadata() + self.check_create_statement(tablemeta, create_statement) + + def test_counter(self): + create_statement = ( + "CREATE TABLE {keyspace}.{table} (" + "key text PRIMARY KEY, a1 counter)" + ).format(keyspace=self.keyspace_name, table=self.function_table_name) + + self.session.execute(create_statement) + tablemeta = self.get_table_metadata() + self.check_create_statement(tablemeta, create_statement) + + @lessthancass40 + def test_counter_with_compact_storage(self): + """ PYTHON-1100 """ + create_statement = ( + "CREATE TABLE {keyspace}.{table} (" + "key text PRIMARY KEY, a1 counter) WITH COMPACT STORAGE" + ).format(keyspace=self.keyspace_name, table=self.function_table_name) + + self.session.execute(create_statement) + tablemeta = self.get_table_metadata() + self.check_create_statement(tablemeta, create_statement) + + @lessthancass40 + def test_counter_with_dense_compact_storage(self): + create_statement = ( + "CREATE TABLE {keyspace}.{table} (" + "key text, c1 text, a1 counter, PRIMARY KEY (key, c1)) WITH COMPACT STORAGE" + ).format(keyspace=self.keyspace_name, table=self.function_table_name) + + self.session.execute(create_statement) + tablemeta = self.get_table_metadata() + self.check_create_statement(tablemeta, create_statement) + def test_indexes(self): create_statement = self.make_create_statement(["a"], ["b", "c"], ["d", "e", "f"]) create_statement += " WITH CLUSTERING ORDER BY (b ASC, c ASC)" @@ -325,9 +491,8 @@ def test_indexes(self): self.assertIn('CREATE INDEX d_index', statement) self.assertIn('CREATE INDEX e_index', statement) + @greaterthancass21 def test_collection_indexes(self): - if CASS_SERVER_VERSION < (2, 1, 0): - raise unittest.SkipTest("Secondary index on collections were introduced in Cassandra 2.1") self.session.execute("CREATE TABLE %s.%s (a int PRIMARY KEY, b map)" % (self.keyspace_name, self.function_table_name)) @@ -342,11 +507,11 @@ def test_collection_indexes(self): % (self.keyspace_name, self.function_table_name)) tablemeta = self.get_table_metadata() - target = ' (b)' if CASS_SERVER_VERSION < (3, 0) else 'values(b))' # explicit values in C* 3+ + target = ' (b)' if CASSANDRA_VERSION < Version("3.0") else 'values(b))' # explicit values in C* 3+ self.assertIn(target, tablemeta.export_as_string()) # test full indexes on frozen collections, if available - if CASS_SERVER_VERSION >= (2, 1, 3): + if CASSANDRA_VERSION >= Version("2.1.3"): 
self.session.execute("DROP TABLE %s.%s" % (self.keyspace_name, self.function_table_name)) self.session.execute("CREATE TABLE %s.%s (a int PRIMARY KEY, b frozen>)" % (self.keyspace_name, self.function_table_name)) @@ -361,7 +526,7 @@ def test_compression_disabled(self): create_statement += " WITH compression = {}" self.session.execute(create_statement) tablemeta = self.get_table_metadata() - expected = "compression = {}" if CASS_SERVER_VERSION < (3, 0) else "compression = {'enabled': 'false'}" + expected = "compression = {}" if CASSANDRA_VERSION < Version("3.0") else "compression = {'enabled': 'false'}" self.assertIn(expected, tablemeta.export_as_string()) def test_non_size_tiered_compaction(self): @@ -385,8 +550,10 @@ def test_non_size_tiered_compaction(self): cql = table_meta.export_as_string() self.assertIn("'tombstone_threshold': '0.3'", cql) self.assertIn("LeveledCompactionStrategy", cql) - self.assertNotIn("min_threshold", cql) - self.assertNotIn("max_threshold", cql) + # formerly legacy options; reintroduced in 4.0 + if CASSANDRA_VERSION < Version('4.0-a'): + self.assertNotIn("min_threshold", cql) + self.assertNotIn("max_threshold", cql) def test_refresh_schema_metadata(self): """ @@ -406,8 +573,7 @@ def test_refresh_schema_metadata(self): @test_category metadata """ - - cluster2 = Cluster(protocol_version=PROTOCOL_VERSION, schema_event_refresh_window=-1) + cluster2 = TestCluster(schema_event_refresh_window=-1) cluster2.connect() self.assertNotIn("new_keyspace", cluster2.metadata.keyspaces) @@ -447,7 +613,7 @@ def test_refresh_schema_metadata(self): self.session.execute("""CREATE FUNCTION {0}.sum_int(key int, val int) RETURNS NULL ON NULL INPUT RETURNS int - LANGUAGE javascript AS 'key + val';""".format(self.keyspace_name)) + LANGUAGE java AS 'return key+val;';""".format(self.keyspace_name)) self.assertEqual(cluster2.metadata.keyspaces[self.keyspace_name].functions, {}) cluster2.refresh_schema_metadata() @@ -490,7 +656,7 @@ def test_refresh_keyspace_metadata(self): @test_category metadata """ - cluster2 = Cluster(protocol_version=PROTOCOL_VERSION, schema_event_refresh_window=-1) + cluster2 = TestCluster(schema_event_refresh_window=-1) cluster2.connect() self.assertTrue(cluster2.metadata.keyspaces[self.keyspace_name].durable_writes) @@ -505,11 +671,11 @@ def test_refresh_table_metadata(self): """ test for synchronously refreshing table metadata - test_refresh_table_metatadata tests that table metadata is refreshed when calling test_refresh_table_metatadata(). + test_refresh_table_metadata tests that table metadata is refreshed when calling test_refresh_table_metadata(). It creates a second cluster object with schema_event_refresh_window=-1 such that schema refreshes are disabled for schema change push events. It then alters the table, adding a new column, using the first cluster object, and verifies that the table metadata has not changed in the second cluster object. Finally, it calls - test_refresh_table_metatadata() and verifies that the table metadata is updated in the second cluster object. + test_refresh_table_metadata() and verifies that the table metadata is updated in the second cluster object. 
@since 2.6.0 @jira_ticket PYTHON-291 @@ -521,7 +687,7 @@ def test_refresh_table_metadata(self): table_name = "test" self.session.execute("CREATE TABLE {0}.{1} (a int PRIMARY KEY, b text)".format(self.keyspace_name, table_name)) - cluster2 = Cluster(protocol_version=PROTOCOL_VERSION, schema_event_refresh_window=-1) + cluster2 = TestCluster(schema_event_refresh_window=-1) cluster2.connect() self.assertNotIn("c", cluster2.metadata.keyspaces[self.keyspace_name].tables[table_name].columns) @@ -533,15 +699,16 @@ def test_refresh_table_metadata(self): cluster2.shutdown() + @greaterthanorequalcass30 def test_refresh_metadata_for_mv(self): """ test for synchronously refreshing materialized view metadata test_refresh_table_metadata_for_materialized_views tests that materialized view metadata is refreshed when calling - test_refresh_table_metatadata() with the materialized view name as the table. It creates a second cluster object + test_refresh_table_metadata() with the materialized view name as the table. It creates a second cluster object with schema_event_refresh_window=-1 such that schema refreshes are disabled for schema change push events. It then creates a new materialized view , using the first cluster object, and verifies that the materialized view - metadata has not changed in the second cluster object. Finally, it calls test_refresh_table_metatadata() with the + metadata has not changed in the second cluster object. Finally, it calls test_refresh_table_metadata() with the materialized view name as the table name, and verifies that the materialized view metadata is updated in the second cluster object. @@ -552,17 +719,15 @@ def test_refresh_metadata_for_mv(self): @test_category metadata """ - if CASS_SERVER_VERSION < (3, 0): - raise unittest.SkipTest("Materialized views require Cassandra 3.0+") - self.session.execute("CREATE TABLE {0}.{1} (a int PRIMARY KEY, b text)".format(self.keyspace_name, self.function_table_name)) - cluster2 = Cluster(protocol_version=PROTOCOL_VERSION, schema_event_refresh_window=-1) + cluster2 = TestCluster(schema_event_refresh_window=-1) cluster2.connect() try: self.assertNotIn("mv1", cluster2.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views) - self.session.execute("CREATE MATERIALIZED VIEW {0}.mv1 AS SELECT b FROM {0}.{1} WHERE b IS NOT NULL PRIMARY KEY (a, b)" + self.session.execute("CREATE MATERIALIZED VIEW {0}.mv1 AS SELECT a, b FROM {0}.{1} " + "WHERE a IS NOT NULL AND b IS NOT NULL PRIMARY KEY (a, b)" .format(self.keyspace_name, self.function_table_name)) self.assertNotIn("mv1", cluster2.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views) @@ -580,12 +745,15 @@ def test_refresh_metadata_for_mv(self): self.assertIsNot(original_meta, self.session.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views['mv1']) self.assertEqual(original_meta.as_cql_query(), current_meta.as_cql_query()) - cluster3 = Cluster(protocol_version=PROTOCOL_VERSION, schema_event_refresh_window=-1) + cluster3 = TestCluster(schema_event_refresh_window=-1) cluster3.connect() try: self.assertNotIn("mv2", cluster3.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views) - self.session.execute("CREATE MATERIALIZED VIEW {0}.mv2 AS SELECT b FROM {0}.{1} WHERE b IS NOT NULL PRIMARY KEY (a, b)" - .format(self.keyspace_name, self.function_table_name)) + self.session.execute( + "CREATE MATERIALIZED VIEW {0}.mv2 AS SELECT a, b FROM {0}.{1} " + "WHERE a IS NOT NULL AND b IS NOT NULL PRIMARY KEY 
(a, b)".format( + self.keyspace_name, self.function_table_name) + ) self.assertNotIn("mv2", cluster3.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views) cluster3.refresh_materialized_view_metadata(self.keyspace_name, 'mv2') self.assertIn("mv2", cluster3.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views) @@ -612,7 +780,7 @@ def test_refresh_user_type_metadata(self): if PROTOCOL_VERSION < 3: raise unittest.SkipTest("Protocol 3+ is required for UDTs, currently testing against {0}".format(PROTOCOL_VERSION)) - cluster2 = Cluster(protocol_version=PROTOCOL_VERSION, schema_event_refresh_window=-1) + cluster2 = TestCluster(schema_event_refresh_window=-1) cluster2.connect() self.assertEqual(cluster2.metadata.keyspaces[self.keyspace_name].user_types, {}) @@ -624,6 +792,41 @@ def test_refresh_user_type_metadata(self): cluster2.shutdown() + + @greaterthancass20 + def test_refresh_user_type_metadata_proto_2(self): + """ + Test to ensure that protocol v1/v2 surface UDT metadata changes + + @since 3.7.0 + @jira_ticket PYTHON-106 + @expected_result UDT metadata in the keyspace should be updated regardless of protocol version + + @test_category metadata + """ + supported_versions = get_supported_protocol_versions() + if 2 not in supported_versions: # 1 and 2 were dropped in the same version + raise unittest.SkipTest("Protocol versions 1 and 2 are not supported in Cassandra version {0}".format(CASSANDRA_VERSION)) + + for protocol_version in (1, 2): + cluster = TestCluster() + session = cluster.connect() + self.assertEqual(cluster.metadata.keyspaces[self.keyspace_name].user_types, {}) + + session.execute("CREATE TYPE {0}.user (age int, name text)".format(self.keyspace_name)) + self.assertIn("user", cluster.metadata.keyspaces[self.keyspace_name].user_types) + self.assertIn("age", cluster.metadata.keyspaces[self.keyspace_name].user_types["user"].field_names) + self.assertIn("name", cluster.metadata.keyspaces[self.keyspace_name].user_types["user"].field_names) + + session.execute("ALTER TYPE {0}.user ADD flag boolean".format(self.keyspace_name)) + self.assertIn("flag", cluster.metadata.keyspaces[self.keyspace_name].user_types["user"].field_names) + + session.execute("ALTER TYPE {0}.user RENAME flag TO something".format(self.keyspace_name)) + self.assertIn("something", cluster.metadata.keyspaces[self.keyspace_name].user_types["user"].field_names) + + session.execute("DROP TYPE {0}.user".format(self.keyspace_name)) + self.assertEqual(cluster.metadata.keyspaces[self.keyspace_name].user_types, {}) + cluster.shutdown() + def test_refresh_user_function_metadata(self): """ test for synchronously refreshing UDF metadata in keyspace @@ -645,14 +848,14 @@ def test_refresh_user_function_metadata(self): if PROTOCOL_VERSION < 4: raise unittest.SkipTest("Protocol 4+ is required for UDFs, currently testing against {0}".format(PROTOCOL_VERSION)) - cluster2 = Cluster(protocol_version=PROTOCOL_VERSION, schema_event_refresh_window=-1) + cluster2 = TestCluster(schema_event_refresh_window=-1) cluster2.connect() self.assertEqual(cluster2.metadata.keyspaces[self.keyspace_name].functions, {}) self.session.execute("""CREATE FUNCTION {0}.sum_int(key int, val int) RETURNS NULL ON NULL INPUT RETURNS int - LANGUAGE javascript AS 'key + val';""".format(self.keyspace_name)) + LANGUAGE java AS ' return key + val;';""".format(self.keyspace_name)) self.assertEqual(cluster2.metadata.keyspaces[self.keyspace_name].functions, {}) cluster2.refresh_user_function_metadata(self.keyspace_name, 
UserFunctionDescriptor("sum_int", ["int", "int"])) @@ -681,14 +884,14 @@ def test_refresh_user_aggregate_metadata(self): if PROTOCOL_VERSION < 4: raise unittest.SkipTest("Protocol 4+ is required for UDAs, currently testing against {0}".format(PROTOCOL_VERSION)) - cluster2 = Cluster(protocol_version=PROTOCOL_VERSION, schema_event_refresh_window=-1) + cluster2 = TestCluster(schema_event_refresh_window=-1) cluster2.connect() self.assertEqual(cluster2.metadata.keyspaces[self.keyspace_name].aggregates, {}) self.session.execute("""CREATE FUNCTION {0}.sum_int(key int, val int) RETURNS NULL ON NULL INPUT RETURNS int - LANGUAGE javascript AS 'key + val';""".format(self.keyspace_name)) + LANGUAGE java AS 'return key + val;';""".format(self.keyspace_name)) self.session.execute("""CREATE AGGREGATE {0}.sum_agg(int) SFUNC sum_int @@ -702,6 +905,7 @@ def test_refresh_user_aggregate_metadata(self): cluster2.shutdown() + @greaterthanorequalcass30 def test_multiple_indices(self): """ test multiple indices on the same column. @@ -714,8 +918,6 @@ def test_multiple_indices(self): @test_category metadata """ - if CASS_SERVER_VERSION < (3, 0): - raise unittest.SkipTest("Materialized views require Cassandra 3.0+") self.session.execute("CREATE TABLE {0}.{1} (a int PRIMARY KEY, b map)".format(self.keyspace_name, self.function_table_name)) + self.session.execute("CREATE INDEX index_1 ON {0}.{1}(b)".format(self.keyspace_name, self.function_table_name)) @@ -736,6 +938,104 @@ def test_multiple_indices(self): self.assertEqual(index_2.index_options["target"], "keys(b)") self.assertEqual(index_2.keyspace_name, "schemametadatatests") + @greaterthanorequalcass30 + def test_table_extensions(self): + s = self.session + ks = self.keyspace_name + ks_meta = s.cluster.metadata.keyspaces[ks] + t = self.function_table_name + v = t + 'view' + + s.execute("CREATE TABLE %s.%s (k text PRIMARY KEY, v int)" % (ks, t)) + s.execute( + "CREATE MATERIALIZED VIEW %s.%s AS SELECT * FROM %s.%s " + "WHERE v IS NOT NULL AND k IS NOT NULL PRIMARY KEY (v, k)" % (ks, v, ks, t) + ) + + table_meta = ks_meta.tables[t] + view_meta = table_meta.views[v] + + self.assertFalse(table_meta.extensions) + self.assertFalse(view_meta.extensions) + + original_table_cql = table_meta.export_as_string() + original_view_cql = view_meta.export_as_string() + + # extensions registered, not present + # -------------------------------------- + class Ext0(RegisteredTableExtension): + name = t + + @classmethod + def after_table_cql(cls, table_meta, ext_key, ext_blob): + return "%s %s %s %s" % (cls.name, table_meta.name, ext_key, ext_blob) + + class Ext1(Ext0): + name = t + '##' + + self.assertFalse(table_meta.extensions) + self.assertFalse(view_meta.extensions) + self.assertIn(Ext0.name, _RegisteredExtensionType._extension_registry) + self.assertIn(Ext1.name, _RegisteredExtensionType._extension_registry) + # The RLAC extension will also be registered here. + self.assertEqual(len(_RegisteredExtensionType._extension_registry), 3) + + self.cluster.refresh_table_metadata(ks, t) + table_meta = ks_meta.tables[t] + view_meta = table_meta.views[v] + + self.assertEqual(table_meta.export_as_string(), original_table_cql) + self.assertEqual(view_meta.export_as_string(), original_view_cql) + + update_t = s.prepare('UPDATE system_schema.tables SET extensions=? WHERE keyspace_name=? AND table_name=?') # for blob type coercing + update_v = s.prepare('UPDATE system_schema.views SET extensions=? WHERE keyspace_name=? 
AND view_name=?') + # extensions registered, one present + # -------------------------------------- + ext_map = {Ext0.name: b"THA VALUE"} + [(s.execute(update_t, (ext_map, ks, t)), s.execute(update_v, (ext_map, ks, v))) + for _ in self.cluster.metadata.all_hosts()] # we're manipulating metadata - do it on all hosts + self.cluster.refresh_table_metadata(ks, t) + self.cluster.refresh_materialized_view_metadata(ks, v) + table_meta = ks_meta.tables[t] + view_meta = table_meta.views[v] + + self.assertIn(Ext0.name, table_meta.extensions) + new_cql = table_meta.export_as_string() + self.assertNotEqual(new_cql, original_table_cql) + self.assertIn(Ext0.after_table_cql(table_meta, Ext0.name, ext_map[Ext0.name]), new_cql) + self.assertNotIn(Ext1.name, new_cql) + + self.assertIn(Ext0.name, view_meta.extensions) + new_cql = view_meta.export_as_string() + self.assertNotEqual(new_cql, original_view_cql) + self.assertIn(Ext0.after_table_cql(view_meta, Ext0.name, ext_map[Ext0.name]), new_cql) + self.assertNotIn(Ext1.name, new_cql) + + # extensions registered, both present + # -------------------------------------- + ext_map = {Ext0.name: b"THA VALUE", + Ext1.name: b"OTHA VALUE"} + [(s.execute(update_t, (ext_map, ks, t)), s.execute(update_v, (ext_map, ks, v))) + for _ in self.cluster.metadata.all_hosts()] # we're manipulating metadata - do it on all hosts + self.cluster.refresh_table_metadata(ks, t) + self.cluster.refresh_materialized_view_metadata(ks, v) + table_meta = ks_meta.tables[t] + view_meta = table_meta.views[v] + + self.assertIn(Ext0.name, table_meta.extensions) + self.assertIn(Ext1.name, table_meta.extensions) + new_cql = table_meta.export_as_string() + self.assertNotEqual(new_cql, original_table_cql) + self.assertIn(Ext0.after_table_cql(table_meta, Ext0.name, ext_map[Ext0.name]), new_cql) + self.assertIn(Ext1.after_table_cql(table_meta, Ext1.name, ext_map[Ext1.name]), new_cql) + + self.assertIn(Ext0.name, view_meta.extensions) + self.assertIn(Ext1.name, view_meta.extensions) + new_cql = view_meta.export_as_string() + self.assertNotEqual(new_cql, original_view_cql) + self.assertIn(Ext0.after_table_cql(view_meta, Ext0.name, ext_map[Ext0.name]), new_cql) + self.assertIn(Ext1.after_table_cql(view_meta, Ext1.name, ext_map[Ext1.name]), new_cql) + class TestCodeCoverage(unittest.TestCase): @@ -744,23 +1044,24 @@ def test_export_schema(self): Test export schema functionality """ - cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cluster = TestCluster() cluster.connect() - self.assertIsInstance(cluster.metadata.export_schema_as_string(), six.string_types) + self.assertIsInstance(cluster.metadata.export_schema_as_string(), str) + cluster.shutdown() def test_export_keyspace_schema(self): """ Test export keyspace schema functionality """ - cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cluster = TestCluster() cluster.connect() for keyspace in cluster.metadata.keyspaces: keyspace_metadata = cluster.metadata.keyspaces[keyspace] - self.assertIsInstance(keyspace_metadata.export_as_string(), six.string_types) - self.assertIsInstance(keyspace_metadata.as_cql_query(), six.string_types) + self.assertIsInstance(keyspace_metadata.export_as_string(), str) + self.assertIsInstance(keyspace_metadata.as_cql_query(), str) cluster.shutdown() def assert_equal_diff(self, received, expected): @@ -780,14 +1081,12 @@ def assert_startswith_diff(self, received, prefix): lineterm='')) self.fail(diff_string) + @greaterthancass20 def test_export_keyspace_schema_udts(self): """ Test udt exports """ - if 
CASS_SERVER_VERSION < (2, 1, 0): - raise unittest.SkipTest('UDTs were introduced in Cassandra 2.1') - if PROTOCOL_VERSION < 3: raise unittest.SkipTest( "Protocol 3.0+ is required for UDT change events, currently testing against %r" @@ -796,7 +1095,7 @@ def test_export_keyspace_schema_udts(self): if sys.version_info[0:2] != (2, 7): raise unittest.SkipTest('This test compares static strings generated from dict items, which may change orders. Test with 2.7.') - cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cluster = TestCluster() session = cluster.connect() session.execute(""" @@ -858,17 +1157,19 @@ def test_export_keyspace_schema_udts(self): cluster.shutdown() + @greaterthancass21 def test_case_sensitivity(self): """ Test that names that need to be escaped in CREATE statements are """ - cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cluster = TestCluster() session = cluster.connect() ksname = 'AnInterestingKeyspace' cfname = 'AnInterestingTable' + session.execute("DROP KEYSPACE IF EXISTS {0}".format(ksname)) session.execute(""" CREATE KEYSPACE "%s" WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} @@ -885,6 +1186,9 @@ def test_case_sensitivity(self): session.execute(""" CREATE INDEX myindex ON "%s"."%s" ("MyColumn") """ % (ksname, cfname)) + session.execute(""" + CREATE INDEX "AnotherIndex" ON "%s"."%s" ("B") + """ % (ksname, cfname)) ksmeta = cluster.metadata.keyspaces[ksname] schema = ksmeta.export_as_string() @@ -896,6 +1200,7 @@ def test_case_sensitivity(self): self.assertIn('PRIMARY KEY (k, "A")', schema) self.assertIn('WITH CLUSTERING ORDER BY ("A" DESC)', schema) self.assertIn('CREATE INDEX myindex ON "AnInterestingKeyspace"."AnInterestingTable" ("MyColumn")', schema) + self.assertIn('CREATE INDEX "AnotherIndex" ON "AnInterestingKeyspace"."AnInterestingTable" ("B")', schema) cluster.shutdown() def test_already_exists_exceptions(self): @@ -903,7 +1208,7 @@ def test_already_exists_exceptions(self): Ensure AlreadyExists exception is thrown when hit """ - cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cluster = TestCluster() session = cluster.connect() ksname = 'test3rf' @@ -921,6 +1226,7 @@ def test_already_exists_exceptions(self): self.assertRaises(AlreadyExists, session.execute, ddl % (ksname, cfname)) cluster.shutdown() + @local def test_replicas(self): """ Ensure cluster.metadata.get_replicas return correctly when not attached to keyspace @@ -928,13 +1234,13 @@ def test_replicas(self): if murmur3 is None: raise unittest.SkipTest('the murmur3 extension is not available') - cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cluster = TestCluster() self.assertEqual(cluster.metadata.get_replicas('test3rf', 'key'), []) cluster.connect('test3rf') - self.assertNotEqual(list(cluster.metadata.get_replicas('test3rf', six.b('key'))), []) - host = list(cluster.metadata.get_replicas('test3rf', six.b('key')))[0] + self.assertNotEqual(list(cluster.metadata.get_replicas('test3rf', b'key')), []) + host = list(cluster.metadata.get_replicas('test3rf', b'key'))[0] self.assertEqual(host.datacenter, 'dc1') self.assertEqual(host.rack, 'r1') cluster.shutdown() @@ -944,7 +1250,7 @@ def test_token_map(self): Test token mappings """ - cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cluster = TestCluster() cluster.connect('test3rf') ring = cluster.metadata.token_map.ring owners = list(cluster.metadata.token_map.token_to_host_owner[token] for token in ring) @@ -955,277 +1261,8 @@ def test_token_map(self): for i, token in enumerate(ring): 
self.assertEqual(set(get_replicas('test3rf', token)), set(owners)) - self.assertEqual(set(get_replicas('test2rf', token)), set([owners[(i + 1) % 3], owners[(i + 2) % 3]])) - self.assertEqual(set(get_replicas('test1rf', token)), set([owners[(i + 1) % 3]])) - cluster.shutdown() - - def test_legacy_tables(self): - - if CASS_SERVER_VERSION < (2, 1, 0): - raise unittest.SkipTest('Test schema output assumes 2.1.0+ options') - - if CASS_SERVER_VERSION >= (2, 2, 0): - raise unittest.SkipTest('Cannot test cli script on Cassandra 2.2.0+') - - if sys.version_info[0:2] != (2, 7): - raise unittest.SkipTest('This test compares static strings generated from dict items, which may change orders. Test with 2.7.') - - cli_script = """CREATE KEYSPACE legacy -WITH placement_strategy = 'SimpleStrategy' -AND strategy_options = {replication_factor:1}; - -USE legacy; - -CREATE COLUMN FAMILY simple_no_col - WITH comparator = UTF8Type - AND key_validation_class = UUIDType - AND default_validation_class = UTF8Type; - -CREATE COLUMN FAMILY simple_with_col - WITH comparator = UTF8Type - and key_validation_class = UUIDType - and default_validation_class = UTF8Type - AND column_metadata = [ - {column_name: col_with_meta, validation_class: UTF8Type} - ]; - -CREATE COLUMN FAMILY composite_partition_no_col - WITH comparator = UTF8Type - AND key_validation_class = 'CompositeType(UUIDType,UTF8Type)' - AND default_validation_class = UTF8Type; - -CREATE COLUMN FAMILY composite_partition_with_col - WITH comparator = UTF8Type - AND key_validation_class = 'CompositeType(UUIDType,UTF8Type)' - AND default_validation_class = UTF8Type - AND column_metadata = [ - {column_name: col_with_meta, validation_class: UTF8Type} - ]; - -CREATE COLUMN FAMILY nested_composite_key - WITH comparator = UTF8Type - and key_validation_class = 'CompositeType(CompositeType(UUIDType,UTF8Type), LongType)' - and default_validation_class = UTF8Type - AND column_metadata = [ - {column_name: full_name, validation_class: UTF8Type} - ]; - -create column family composite_comp_no_col - with column_type = 'Standard' - and comparator = 'DynamicCompositeType(t=>org.apache.cassandra.db.marshal.TimeUUIDType,s=>org.apache.cassandra.db.marshal.UTF8Type,b=>org.apache.cassandra.db.marshal.BytesType)' - and default_validation_class = 'BytesType' - and key_validation_class = 'BytesType' - and read_repair_chance = 0.0 - and dclocal_read_repair_chance = 0.1 - and gc_grace = 864000 - and min_compaction_threshold = 4 - and max_compaction_threshold = 32 - and compaction_strategy = 'org.apache.cassandra.db.compaction.SizeTieredCompactionStrategy' - and caching = 'KEYS_ONLY' - and cells_per_row_to_cache = '0' - and default_time_to_live = 0 - and speculative_retry = 'NONE' - and comment = 'Stores file meta data'; - -create column family composite_comp_with_col - with column_type = 'Standard' - and comparator = 'DynamicCompositeType(t=>org.apache.cassandra.db.marshal.TimeUUIDType,s=>org.apache.cassandra.db.marshal.UTF8Type,b=>org.apache.cassandra.db.marshal.BytesType)' - and default_validation_class = 'BytesType' - and key_validation_class = 'BytesType' - and read_repair_chance = 0.0 - and dclocal_read_repair_chance = 0.1 - and gc_grace = 864000 - and min_compaction_threshold = 4 - and max_compaction_threshold = 32 - and compaction_strategy = 'org.apache.cassandra.db.compaction.SizeTieredCompactionStrategy' - and caching = 'KEYS_ONLY' - and cells_per_row_to_cache = '0' - and default_time_to_live = 0 - and speculative_retry = 'NONE' - and comment = 'Stores file meta data' - and 
column_metadata = [ - {column_name : 'b@6d616d6d616a616d6d61', - validation_class : BytesType, - index_name : 'idx_one', - index_type : 0}, - {column_name : 'b@6869746d65776974686d75736963', - validation_class : BytesType, - index_name : 'idx_two', - index_type : 0}] - and compression_options = {'sstable_compression' : 'org.apache.cassandra.io.compress.LZ4Compressor'};""" - - # note: the inner key type for legacy.nested_composite_key - # (org.apache.cassandra.db.marshal.CompositeType(org.apache.cassandra.db.marshal.UUIDType, org.apache.cassandra.db.marshal.UTF8Type)) - # is a bit strange, but it replays in CQL with desired results - expected_string = """CREATE KEYSPACE legacy WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} AND durable_writes = true; - -/* -Warning: Table legacy.composite_comp_with_col omitted because it has constructs not compatible with CQL (was created via legacy API). - -Approximate structure, for reference: -(this should not be used to reproduce this schema) - -CREATE TABLE legacy.composite_comp_with_col ( - key blob, - t timeuuid, - b blob, - s text, - "b@6869746d65776974686d75736963" blob, - "b@6d616d6d616a616d6d61" blob, - PRIMARY KEY (key, t, b, s) -) WITH COMPACT STORAGE - AND CLUSTERING ORDER BY (t ASC, b ASC, s ASC) - AND caching = '{"keys":"ALL", "rows_per_partition":"NONE"}' - AND comment = 'Stores file meta data' - AND compaction = {'min_threshold': '4', 'class': 'org.apache.cassandra.db.compaction.SizeTieredCompactionStrategy', 'max_threshold': '32'} - AND compression = {'sstable_compression': 'org.apache.cassandra.io.compress.LZ4Compressor'} - AND dclocal_read_repair_chance = 0.1 - AND default_time_to_live = 0 - AND gc_grace_seconds = 864000 - AND max_index_interval = 2048 - AND memtable_flush_period_in_ms = 0 - AND min_index_interval = 128 - AND read_repair_chance = 0.0 - AND speculative_retry = 'NONE'; -CREATE INDEX idx_two ON legacy.composite_comp_with_col ("b@6869746d65776974686d75736963"); -CREATE INDEX idx_one ON legacy.composite_comp_with_col ("b@6d616d6d616a616d6d61"); -*/ - -CREATE TABLE legacy.nested_composite_key ( - key 'org.apache.cassandra.db.marshal.CompositeType(org.apache.cassandra.db.marshal.UUIDType, org.apache.cassandra.db.marshal.UTF8Type)', - key2 bigint, - full_name text, - PRIMARY KEY ((key, key2)) -) WITH COMPACT STORAGE - AND caching = '{"keys":"ALL", "rows_per_partition":"NONE"}' - AND comment = '' - AND compaction = {'class': 'org.apache.cassandra.db.compaction.SizeTieredCompactionStrategy'} - AND compression = {'sstable_compression': 'org.apache.cassandra.io.compress.LZ4Compressor'} - AND dclocal_read_repair_chance = 0.1 - AND default_time_to_live = 0 - AND gc_grace_seconds = 864000 - AND max_index_interval = 2048 - AND memtable_flush_period_in_ms = 0 - AND min_index_interval = 128 - AND read_repair_chance = 0.0 - AND speculative_retry = 'NONE'; - -CREATE TABLE legacy.composite_partition_with_col ( - key uuid, - key2 text, - col_with_meta text, - PRIMARY KEY ((key, key2)) -) WITH COMPACT STORAGE - AND caching = '{"keys":"ALL", "rows_per_partition":"NONE"}' - AND comment = '' - AND compaction = {'class': 'org.apache.cassandra.db.compaction.SizeTieredCompactionStrategy'} - AND compression = {'sstable_compression': 'org.apache.cassandra.io.compress.LZ4Compressor'} - AND dclocal_read_repair_chance = 0.1 - AND default_time_to_live = 0 - AND gc_grace_seconds = 864000 - AND max_index_interval = 2048 - AND memtable_flush_period_in_ms = 0 - AND min_index_interval = 128 - AND read_repair_chance = 0.0 - AND 
speculative_retry = 'NONE'; - -CREATE TABLE legacy.composite_partition_no_col ( - key uuid, - key2 text, - column1 text, - value text, - PRIMARY KEY ((key, key2), column1) -) WITH COMPACT STORAGE - AND CLUSTERING ORDER BY (column1 ASC) - AND caching = '{"keys":"ALL", "rows_per_partition":"NONE"}' - AND comment = '' - AND compaction = {'class': 'org.apache.cassandra.db.compaction.SizeTieredCompactionStrategy'} - AND compression = {'sstable_compression': 'org.apache.cassandra.io.compress.LZ4Compressor'} - AND dclocal_read_repair_chance = 0.1 - AND default_time_to_live = 0 - AND gc_grace_seconds = 864000 - AND max_index_interval = 2048 - AND memtable_flush_period_in_ms = 0 - AND min_index_interval = 128 - AND read_repair_chance = 0.0 - AND speculative_retry = 'NONE'; - -CREATE TABLE legacy.simple_with_col ( - key uuid PRIMARY KEY, - col_with_meta text -) WITH COMPACT STORAGE - AND caching = '{"keys":"ALL", "rows_per_partition":"NONE"}' - AND comment = '' - AND compaction = {'class': 'org.apache.cassandra.db.compaction.SizeTieredCompactionStrategy'} - AND compression = {'sstable_compression': 'org.apache.cassandra.io.compress.LZ4Compressor'} - AND dclocal_read_repair_chance = 0.1 - AND default_time_to_live = 0 - AND gc_grace_seconds = 864000 - AND max_index_interval = 2048 - AND memtable_flush_period_in_ms = 0 - AND min_index_interval = 128 - AND read_repair_chance = 0.0 - AND speculative_retry = 'NONE'; - -CREATE TABLE legacy.simple_no_col ( - key uuid, - column1 text, - value text, - PRIMARY KEY (key, column1) -) WITH COMPACT STORAGE - AND CLUSTERING ORDER BY (column1 ASC) - AND caching = '{"keys":"ALL", "rows_per_partition":"NONE"}' - AND comment = '' - AND compaction = {'class': 'org.apache.cassandra.db.compaction.SizeTieredCompactionStrategy'} - AND compression = {'sstable_compression': 'org.apache.cassandra.io.compress.LZ4Compressor'} - AND dclocal_read_repair_chance = 0.1 - AND default_time_to_live = 0 - AND gc_grace_seconds = 864000 - AND max_index_interval = 2048 - AND memtable_flush_period_in_ms = 0 - AND min_index_interval = 128 - AND read_repair_chance = 0.0 - AND speculative_retry = 'NONE'; - -/* -Warning: Table legacy.composite_comp_no_col omitted because it has constructs not compatible with CQL (was created via legacy API). 
- -Approximate structure, for reference: -(this should not be used to reproduce this schema) - -CREATE TABLE legacy.composite_comp_no_col ( - key blob, - column1 'org.apache.cassandra.db.marshal.DynamicCompositeType(org.apache.cassandra.db.marshal.TimeUUIDType, org.apache.cassandra.db.marshal.BytesType, org.apache.cassandra.db.marshal.UTF8Type)', - column2 text, - value blob, - PRIMARY KEY (key, column1, column1, column2) -) WITH COMPACT STORAGE - AND CLUSTERING ORDER BY (column1 ASC, column1 ASC, column2 ASC) - AND caching = '{"keys":"ALL", "rows_per_partition":"NONE"}' - AND comment = 'Stores file meta data' - AND compaction = {'min_threshold': '4', 'class': 'org.apache.cassandra.db.compaction.SizeTieredCompactionStrategy', 'max_threshold': '32'} - AND compression = {'sstable_compression': 'org.apache.cassandra.io.compress.LZ4Compressor'} - AND dclocal_read_repair_chance = 0.1 - AND default_time_to_live = 0 - AND gc_grace_seconds = 864000 - AND max_index_interval = 2048 - AND memtable_flush_period_in_ms = 0 - AND min_index_interval = 128 - AND read_repair_chance = 0.0 - AND speculative_retry = 'NONE'; -*/""" - - ccm = get_cluster() - ccm.run_cli(cli_script) - - cluster = Cluster(protocol_version=PROTOCOL_VERSION) - session = cluster.connect() - - legacy_meta = cluster.metadata.keyspaces['legacy'] - self.assert_equal_diff(legacy_meta.export_as_string(), expected_string) - - session.execute('DROP KEYSPACE legacy') - + self.assertEqual(set(get_replicas('test2rf', token)), set([owners[i], owners[(i + 1) % 3]])) + self.assertEqual(set(get_replicas('test1rf', token)), set([owners[i]])) cluster.shutdown() @@ -1233,49 +1270,24 @@ class TokenMetadataTest(unittest.TestCase): """ Test of TokenMap creation and other behavior. """ - + @local def test_token(self): expected_node_count = len(get_cluster().nodes) - cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cluster = TestCluster() cluster.connect() tmap = cluster.metadata.token_map self.assertTrue(issubclass(tmap.token_class, Token)) self.assertEqual(expected_node_count, len(tmap.ring)) cluster.shutdown() - def test_getting_replicas(self): - tokens = [MD5Token(str(i)) for i in range(0, (2 ** 127 - 1), 2 ** 125)] - hosts = [Host("ip%d" % i, SimpleConvictionPolicy) for i in range(len(tokens))] - token_to_primary_replica = dict(zip(tokens, hosts)) - keyspace = KeyspaceMetadata("ks", True, "SimpleStrategy", {"replication_factor": "1"}) - metadata = Mock(spec=Metadata, keyspaces={'ks': keyspace}) - token_map = TokenMap(MD5Token, token_to_primary_replica, tokens, metadata) - - # tokens match node tokens exactly - for i, token in enumerate(tokens): - expected_host = hosts[(i + 1) % len(hosts)] - replicas = token_map.get_replicas("ks", token) - self.assertEqual(set(replicas), set([expected_host])) - - # shift the tokens back by one - for token, expected_host in zip(tokens, hosts): - replicas = token_map.get_replicas("ks", MD5Token(str(token.value - 1))) - self.assertEqual(set(replicas), set([expected_host])) - - # shift the tokens forward by one - for i, token in enumerate(tokens): - replicas = token_map.get_replicas("ks", MD5Token(str(token.value + 1))) - expected_host = hosts[(i + 1) % len(hosts)] - self.assertEqual(set(replicas), set([expected_host])) - class KeyspaceAlterMetadata(unittest.TestCase): """ Test verifies that table metadata is preserved on keyspace alter """ def setUp(self): - self.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + self.cluster = TestCluster() self.session = self.cluster.connect() name = 
self._testMethodName.lower() crt_ks = ''' @@ -1320,7 +1332,7 @@ def table_name(self): @classmethod def setup_class(cls): - cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cls.cluster = TestCluster() cls.session = cls.cluster.connect() try: if cls.keyspace_name in cls.cluster.metadata.keyspaces: @@ -1429,7 +1441,7 @@ def function_name(self): @classmethod def setup_class(cls): if PROTOCOL_VERSION >= 4: - cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cls.cluster = TestCluster() cls.keyspace_name = cls.__name__.lower() cls.session = cls.cluster.connect() cls.session.execute("CREATE KEYSPACE IF NOT EXISTS %s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}" % cls.keyspace_name) @@ -1491,7 +1503,10 @@ def make_function_kwargs(self, called_on_null=True): 'return_type': 'double', 'language': 'java', 'body': 'return new Double(0.0);', - 'called_on_null_input': called_on_null} + 'called_on_null_input': called_on_null, + 'deterministic': False, + 'monotonic': False, + 'monotonic_on': []} def test_functions_after_udt(self): """ @@ -1576,7 +1591,7 @@ def test_function_no_parameters(self): with self.VerifiedFunction(self, **kwargs) as vf: fn_meta = self.keyspace_function_meta[vf.signature] - self.assertRegexpMatches(fn_meta.as_cql_query(), "CREATE FUNCTION.*%s\(\) .*" % kwargs['name']) + self.assertRegex(fn_meta.as_cql_query(), "CREATE FUNCTION.*%s\(\) .*" % kwargs['name']) def test_functions_follow_keyspace_alter(self): """ @@ -1624,12 +1639,12 @@ def test_function_cql_called_on_null(self): kwargs['called_on_null_input'] = True with self.VerifiedFunction(self, **kwargs) as vf: fn_meta = self.keyspace_function_meta[vf.signature] - self.assertRegexpMatches(fn_meta.as_cql_query(), "CREATE FUNCTION.*\) CALLED ON NULL INPUT RETURNS .*") + self.assertRegex(fn_meta.as_cql_query(), "CREATE FUNCTION.*\) CALLED ON NULL INPUT RETURNS .*") kwargs['called_on_null_input'] = False with self.VerifiedFunction(self, **kwargs) as vf: fn_meta = self.keyspace_function_meta[vf.signature] - self.assertRegexpMatches(fn_meta.as_cql_query(), "CREATE FUNCTION.*\) RETURNS NULL ON NULL INPUT RETURNS .*") + self.assertRegex(fn_meta.as_cql_query(), "CREATE FUNCTION.*\) RETURNS NULL ON NULL INPUT RETURNS .*") class AggregateMetadata(FunctionTest): @@ -1642,15 +1657,15 @@ def setup_class(cls): cls.session.execute("""CREATE OR REPLACE FUNCTION sum_int(s int, i int) RETURNS NULL ON NULL INPUT RETURNS int - LANGUAGE javascript AS 's + i';""") + LANGUAGE java AS 'return s + i;';""") cls.session.execute("""CREATE OR REPLACE FUNCTION sum_int_two(s int, i int, j int) RETURNS NULL ON NULL INPUT RETURNS int - LANGUAGE javascript AS 's + i + j';""") + LANGUAGE java AS 'return s + i + j;';""") cls.session.execute("""CREATE OR REPLACE FUNCTION "List_As_String"(l list) RETURNS NULL ON NULL INPUT RETURNS int - LANGUAGE javascript AS ''''' + l';""") + LANGUAGE java AS 'return l.size();';""") cls.session.execute("""CREATE OR REPLACE FUNCTION extend_list(s list, i int) CALLED ON NULL INPUT RETURNS list @@ -1673,7 +1688,8 @@ def make_aggregate_kwargs(self, state_func, state_type, final_func=None, init_co 'state_type': state_type, 'final_func': final_func, 'initial_condition': init_cond, - 'return_type': "does not matter for creation"} + 'return_type': "does not matter for creation", + 'deterministic': False} def test_return_type_meta(self): """ @@ -1707,7 +1723,7 @@ def test_init_cond(self): """ # This is required until the java driver bundled with C* is updated to support v4 - c = 
Cluster(protocol_version=3) + c = TestCluster(protocol_version=3) s = c.connect(self.keyspace_name) encoder = Encoder() @@ -1737,9 +1753,9 @@ def test_init_cond(self): cql_init = encoder.cql_encode_all_types(init_cond) with self.VerifiedAggregate(self, **self.make_aggregate_kwargs('update_map', 'map', init_cond=cql_init)) as va: map_res = s.execute("SELECT %s(v) AS map_res FROM t" % va.function_kwargs['name'])[0].map_res - self.assertDictContainsSubset(expected_map_values, map_res) + self.assertLessEqual(expected_map_values.items(), map_res.items()) init_not_updated = dict((k, init_cond[k]) for k in set(init_cond) - expected_key_set) - self.assertDictContainsSubset(init_not_updated, map_res) + self.assertLessEqual(init_not_updated.items(), map_res.items()) c.shutdown() def test_aggregates_after_functions(self): @@ -1891,22 +1907,25 @@ def function_name(self): @classmethod def setup_class(cls): - cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cls.cluster = TestCluster() cls.keyspace_name = cls.__name__.lower() cls.session = cls.cluster.connect() cls.session.execute("CREATE KEYSPACE %s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}" % cls.keyspace_name) cls.session.set_keyspace(cls.keyspace_name) connection = cls.cluster.control_connection._connection - cls.parser_class = get_schema_parser(connection, timeout=20).__class__ + + cls.parser_class = get_schema_parser( + connection, + cls.cluster.metadata.get_host(connection.host).release_version, + cls.cluster.metadata.get_host(connection.host).dse_version, + timeout=20 + ).__class__ + cls.cluster.control_connection.reconnect = Mock() @classmethod def teardown_class(cls): drop_keyspace_shutdown_cluster(cls.keyspace_name, cls.session, cls.cluster) - def _skip_if_not_version(self, version): - if CASS_SERVER_VERSION < version: - raise unittest.SkipTest("Requires server version >= %s" % (version,)) - def test_bad_keyspace(self): with patch.object(self.parser_class, '_build_keyspace_metadata_internal', side_effect=self.BadMetaException): self.cluster.refresh_keyspace_metadata(self.keyspace_name) @@ -1931,8 +1950,8 @@ def test_bad_index(self): self.assertIs(m._exc_info[0], self.BadMetaException) self.assertIn("/*\nWarning:", m.export_as_string()) + @greaterthancass20 def test_bad_user_type(self): - self._skip_if_not_version((2, 1, 0)) self.session.execute('CREATE TYPE %s (i int, d double)' % self.function_name) with patch.object(self.parser_class, '_build_user_type', side_effect=self.BadMetaException): self.cluster.refresh_schema_metadata() # presently do not capture these errors on udt direct refresh -- make sure it's contained during full refresh @@ -1940,42 +1959,85 @@ def test_bad_user_type(self): self.assertIs(m._exc_info[0], self.BadMetaException) self.assertIn("/*\nWarning:", m.export_as_string()) + @greaterthancass21 def test_bad_user_function(self): - self._skip_if_not_version((2, 2, 0)) self.session.execute("""CREATE FUNCTION IF NOT EXISTS %s (key int, val int) RETURNS NULL ON NULL INPUT RETURNS int - LANGUAGE javascript AS 'key + val';""" % self.function_name) - with patch.object(self.parser_class, '_build_function', side_effect=self.BadMetaException): - self.cluster.refresh_schema_metadata() # presently do not capture these errors on udt direct refresh -- make sure it's contained during full refresh - m = self.cluster.metadata.keyspaces[self.keyspace_name] - self.assertIs(m._exc_info[0], self.BadMetaException) - self.assertIn("/*\nWarning:", m.export_as_string()) - + LANGUAGE java AS 'return key + 
val;';""" % self.function_name) + + # We also need to patch the reconnect function: after patching _build_function, + # refreshing the schema raises an error, which triggers a reconnection. If that + # reconnection happened during the call to self.cluster.refresh_schema_metadata(), + # the call would raise an exception because a connection would be closed + with patch.object(self.cluster.control_connection, 'reconnect'): + with patch.object(self.parser_class, '_build_function', side_effect=self.BadMetaException): + self.cluster.refresh_schema_metadata() # presently do not capture these errors on udt direct refresh -- make sure it's contained during full refresh + m = self.cluster.metadata.keyspaces[self.keyspace_name] + self.assertIs(m._exc_info[0], self.BadMetaException) + self.assertIn("/*\nWarning:", m.export_as_string()) + + @greaterthancass21 def test_bad_user_aggregate(self): - self._skip_if_not_version((2, 2, 0)) self.session.execute("""CREATE FUNCTION IF NOT EXISTS sum_int (key int, val int) RETURNS NULL ON NULL INPUT RETURNS int - LANGUAGE javascript AS 'key + val';""") + LANGUAGE java AS 'return key + val;';""") self.session.execute("""CREATE AGGREGATE %s(int) SFUNC sum_int STYPE int INITCOND 0""" % self.function_name) - with patch.object(self.parser_class, '_build_aggregate', side_effect=self.BadMetaException): - self.cluster.refresh_schema_metadata() # presently do not capture these errors on udt direct refresh -- make sure it's contained during full refresh - m = self.cluster.metadata.keyspaces[self.keyspace_name] - self.assertIs(m._exc_info[0], self.BadMetaException) - self.assertIn("/*\nWarning:", m.export_as_string()) + # We have the same issue here as in test_bad_user_function + with patch.object(self.cluster.control_connection, 'reconnect'): + with patch.object(self.parser_class, '_build_aggregate', side_effect=self.BadMetaException): + self.cluster.refresh_schema_metadata() # presently do not capture these errors on udt direct refresh -- make sure it's contained during full refresh + m = self.cluster.metadata.keyspaces[self.keyspace_name] + self.assertIs(m._exc_info[0], self.BadMetaException) + self.assertIn("/*\nWarning:", m.export_as_string()) + +class DynamicCompositeTypeTest(BasicSharedKeyspaceUnitTestCase): + + def test_dct_alias(self): + """ + Tests to make sure DCTs have correct string formatting + Constructs a DCT and checks the generated format 
to ensure it matches what is expected + + @since 3.6.0 + @jira_ticket PYTHON-579 + @expected_result DCT subtypes should always have fully qualified names + + @test_category metadata + """ + self.session.execute("CREATE TABLE {0}.{1} (" + "k int PRIMARY KEY," + "c1 'DynamicCompositeType(s => UTF8Type, i => Int32Type)'," + "c2 Text)".format(self.ks_name, self.function_table_name)) + dct_table = self.cluster.metadata.keyspaces.get(self.ks_name).tables.get(self.function_table_name) + + # Format can vary slightly between versions; strip out whitespace for consistency's sake + table_text = dct_table.as_cql_query().replace(" ", "") + dynamic_type_text = "c1'org.apache.cassandra.db.marshal.DynamicCompositeType(" + self.assertIn("c1'org.apache.cassandra.db.marshal.DynamicCompositeType(", table_text) + # Types within the composite can come out in random order, so grab the type definition and find each one + type_definition_start = table_text.index("(", table_text.find(dynamic_type_text)) + type_definition_end = table_text.index(")") + type_definition_text = table_text[type_definition_start:type_definition_end] + self.assertIn("s=>org.apache.cassandra.db.marshal.UTF8Type", type_definition_text) + self.assertIn("i=>org.apache.cassandra.db.marshal.Int32Type", type_definition_text) + + +@greaterthanorequalcass30 class MaterializedViewMetadataTestSimple(BasicSharedKeyspaceUnitTestCase): def setUp(self): - if CASS_SERVER_VERSION < (3, 0): - raise unittest.SkipTest("Materialized views require Cassandra 3.0+") self.session.execute("CREATE TABLE {0}.{1} (pk int PRIMARY KEY, c int)".format(self.keyspace_name, self.function_table_name)) - self.session.execute("CREATE MATERIALIZED VIEW {0}.mv1 AS SELECT c FROM {0}.{1} WHERE c IS NOT NULL PRIMARY KEY (pk, c)".format(self.keyspace_name, self.function_table_name)) + self.session.execute( + "CREATE MATERIALIZED VIEW {0}.mv1 AS SELECT pk, c FROM {0}.{1} " + "WHERE pk IS NOT NULL AND c IS NOT NULL PRIMARY KEY (pk, c)".format( + self.keyspace_name, self.function_table_name) + ) def tearDown(self): self.session.execute("DROP MATERIALIZED VIEW {0}.mv1".format(self.keyspace_name)) @@ -2018,10 +2080,15 @@ def test_materialized_view_metadata_alter(self): @test_category metadata """ - self.assertIn("SizeTieredCompactionStrategy", self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views["mv1"].options["compaction"]["class"] ) + compaction = self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views["mv1"].options["compaction"]["class"] + if HCD_VERSION: + self.assertIn("UnifiedCompactionStrategy", compaction) + else: + self.assertIn("SizeTieredCompactionStrategy", compaction) self.session.execute("ALTER MATERIALIZED VIEW {0}.mv1 WITH compaction = {{ 'class' : 'LeveledCompactionStrategy' }}".format(self.keyspace_name)) - self.assertIn("LeveledCompactionStrategy", self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views["mv1"].options["compaction"]["class"]) + compaction = self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views["mv1"].options["compaction"]["class"] + self.assertIn("LeveledCompactionStrategy", compaction) def test_materialized_view_metadata_drop(self): """ @@ -2046,15 +2113,15 @@ def test_materialized_view_metadata_drop(self): self.assertDictEqual({}, self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views) self.assertDictEqual({}, self.cluster.metadata.keyspaces[self.keyspace_name].views) - 
self.session.execute("CREATE MATERIALIZED VIEW {0}.mv1 AS SELECT c FROM {0}.{1} WHERE c IS NOT NULL PRIMARY KEY (pk, c)".format(self.keyspace_name, self.function_table_name)) + self.session.execute( + "CREATE MATERIALIZED VIEW {0}.mv1 AS SELECT pk, c FROM {0}.{1} " + "WHERE pk IS NOT NULL AND c IS NOT NULL PRIMARY KEY (pk, c)".format( + self.keyspace_name, self.function_table_name) + ) +@greaterthanorequalcass30 class MaterializedViewMetadataTestComplex(BasicSegregatedKeyspaceUnitTestCase): - def setUp(self): - if CASS_SERVER_VERSION < (3, 0): - raise unittest.SkipTest("Materialized views require Cassandra 3.0+") - super(MaterializedViewMetadataTestComplex, self).setUp() - def test_create_view_metadata(self): """ test to ensure that materialized view metadata is properly constructed @@ -2117,37 +2184,37 @@ def test_create_view_metadata(self): self.assertIsNotNone(score_table.columns['score']) # Validate basic mv information - self.assertEquals(mv.keyspace_name, self.keyspace_name) - self.assertEquals(mv.name, "monthlyhigh") - self.assertEquals(mv.base_table_name, "scores") + self.assertEqual(mv.keyspace_name, self.keyspace_name) + self.assertEqual(mv.name, "monthlyhigh") + self.assertEqual(mv.base_table_name, "scores") self.assertFalse(mv.include_all_columns) # Validate that all columns are preset and correct mv_columns = list(mv.columns.values()) - self.assertEquals(len(mv_columns), 6) + self.assertEqual(len(mv_columns), 6) game_column = mv_columns[0] self.assertIsNotNone(game_column) - self.assertEquals(game_column.name, 'game') - self.assertEquals(game_column, mv.partition_key[0]) + self.assertEqual(game_column.name, 'game') + self.assertEqual(game_column, mv.partition_key[0]) year_column = mv_columns[1] self.assertIsNotNone(year_column) - self.assertEquals(year_column.name, 'year') - self.assertEquals(year_column, mv.partition_key[1]) + self.assertEqual(year_column.name, 'year') + self.assertEqual(year_column, mv.partition_key[1]) month_column = mv_columns[2] self.assertIsNotNone(month_column) - self.assertEquals(month_column.name, 'month') - self.assertEquals(month_column, mv.partition_key[2]) + self.assertEqual(month_column.name, 'month') + self.assertEqual(month_column, mv.partition_key[2]) def compare_columns(a, b, name): - self.assertEquals(a.name, name) - self.assertEquals(a.name, b.name) - self.assertEquals(a.table, b.table) - self.assertEquals(a.cql_type, b.cql_type) - self.assertEquals(a.is_static, b.is_static) - self.assertEquals(a.is_reversed, b.is_reversed) + self.assertEqual(a.name, name) + self.assertEqual(a.name, b.name) + self.assertEqual(a.table, b.table) + self.assertEqual(a.cql_type, b.cql_type) + self.assertEqual(a.is_static, b.is_static) + self.assertEqual(a.is_reversed, b.is_reversed) score_column = mv_columns[3] compare_columns(score_column, mv.clustering_key[0], 'score') @@ -2193,7 +2260,7 @@ def test_base_table_column_addition_mv(self): SELECT * FROM {0}.scores WHERE game IS NOT NULL AND score IS NOT NULL AND user IS NOT NULL AND year IS NOT NULL AND month IS NOT NULL AND day IS NOT NULL PRIMARY KEY (game, score, user, year, month, day) - WITH CLUSTERING ORDER BY (score DESC)""".format(self.keyspace_name) + WITH CLUSTERING ORDER BY (score DESC, user ASC, year ASC, month ASC, day ASC)""".format(self.keyspace_name) self.session.execute(create_mv) @@ -2224,8 +2291,9 @@ def test_base_table_column_addition_mv(self): self.assertIn("fouls", mv_alltime.columns) mv_alltime_fouls_comumn = 
self.cluster.metadata.keyspaces[self.keyspace_name].views["alltimehigh"].columns['fouls'] - self.assertEquals(mv_alltime_fouls_comumn.cql_type, 'int') + self.assertEqual(mv_alltime_fouls_comumn.cql_type, 'int') + @lessthancass30 def test_base_table_type_alter_mv(self): """ test to ensure that materialized view metadata is properly updated when a type in the base table @@ -2234,6 +2302,8 @@ def test_base_table_type_alter_mv(self): test_create_view_metadata tests that materialized views metadata is properly updated when the type of base table column is changed. + Support for alter type was removed in CASSANDRA-12443 + @since 3.0.0 @jira_ticket CASSANDRA-10424 @expected_result Materialized view metadata should be updated correctly @@ -2265,7 +2335,7 @@ def test_base_table_type_alter_mv(self): self.assertEqual(len(self.cluster.metadata.keyspaces[self.keyspace_name].views), 1) score_column = self.cluster.metadata.keyspaces[self.keyspace_name].tables['scores'].columns['score'] - self.assertEquals(score_column.cql_type, 'blob') + self.assertEqual(score_column.cql_type, 'blob') # until CASSANDRA-9920+CASSANDRA-10500 MV updates are only available later with an async event for i in range(10): @@ -2274,7 +2344,7 @@ def test_base_table_type_alter_mv(self): break time.sleep(.2) - self.assertEquals(score_mv_column.cql_type, 'blob') + self.assertEqual(score_mv_column.cql_type, 'blob') def test_metadata_with_quoted_identifiers(self): """ @@ -2327,28 +2397,136 @@ def test_metadata_with_quoted_identifiers(self): self.assertIsNotNone(t1_table.columns['the Value']) # Validate basic mv information - self.assertEquals(mv.keyspace_name, self.keyspace_name) - self.assertEquals(mv.name, "mv1") - self.assertEquals(mv.base_table_name, "t1") + self.assertEqual(mv.keyspace_name, self.keyspace_name) + self.assertEqual(mv.name, "mv1") + self.assertEqual(mv.base_table_name, "t1") self.assertFalse(mv.include_all_columns) # Validate that all columns are preset and correct mv_columns = list(mv.columns.values()) - self.assertEquals(len(mv_columns), 3) + self.assertEqual(len(mv_columns), 3) theKey_column = mv_columns[0] self.assertIsNotNone(theKey_column) - self.assertEquals(theKey_column.name, 'theKey') - self.assertEquals(theKey_column, mv.partition_key[0]) + self.assertEqual(theKey_column.name, 'theKey') + self.assertEqual(theKey_column, mv.partition_key[0]) cluster_column = mv_columns[1] self.assertIsNotNone(cluster_column) - self.assertEquals(cluster_column.name, 'the;Clustering') - self.assertEquals(cluster_column.name, mv.clustering_key[0].name) - self.assertEquals(cluster_column.table, mv.clustering_key[0].table) - self.assertEquals(cluster_column.is_static, mv.clustering_key[0].is_static) - self.assertEquals(cluster_column.is_reversed, mv.clustering_key[0].is_reversed) + self.assertEqual(cluster_column.name, 'the;Clustering') + self.assertEqual(cluster_column.name, mv.clustering_key[0].name) + self.assertEqual(cluster_column.table, mv.clustering_key[0].table) + self.assertEqual(cluster_column.is_static, mv.clustering_key[0].is_static) + self.assertEqual(cluster_column.is_reversed, mv.clustering_key[0].is_reversed) value_column = mv_columns[2] self.assertIsNotNone(value_column) - self.assertEquals(value_column.name, 'the Value') + self.assertEqual(value_column.name, 'the Value') + + @greaterthanorequaldse51 + def test_dse_workloads(self): + """ + Test to ensure dse_workloads is populated appropriately. 
+ Field added in DSE 5.1 + + @jira_ticket PYTHON-667 + @expected_result dse_workloads set is set on host model + + @test_category metadata + """ + for host in self.cluster.metadata.all_hosts(): + self.assertIsInstance(host.dse_workloads, SortedSet) + self.assertIn("Cassandra", host.dse_workloads) + + +class GroupPerHost(BasicSharedKeyspaceUnitTestCase): + @classmethod + def setUpClass(cls): + cls.common_setup(rf=1, create_class_table=True) + cls.table_two_pk = "table_with_two_pk" + cls.session.execute( + ''' + CREATE TABLE {0}.{1} ( + k_one int, + k_two int, + v int, + PRIMARY KEY ((k_one, k_two)) + )'''.format(cls.ks_name, cls.table_two_pk) + ) + + def test_group_keys_by_host(self): + """ + Test to ensure group_keys_by_host functions as expected. It is tried + with a table with a single field for the partition key and a table + with two fields for the partition key + @since 3.13 + @jira_ticket PYTHON-647 + @expected_result group_keys_by_host return the expected value + + @test_category metadata + """ + stmt = """SELECT * FROM {}.{} + WHERE k_one = ? AND k_two = ? """.format(self.ks_name, self.table_two_pk) + keys = ((1, 2), (2, 2), (2, 3), (3, 4)) + self._assert_group_keys_by_host(keys, self.table_two_pk, stmt) + + stmt = """SELECT * FROM {}.{} + WHERE k = ? """.format(self.ks_name, self.ks_name) + keys = ((1,), (2,), (2,), (3,)) + self._assert_group_keys_by_host(keys, self.ks_name, stmt) + + def _assert_group_keys_by_host(self, keys, table_name, stmt): + keys_per_host = group_keys_by_replica(self.session, self.ks_name, table_name, keys) + self.assertNotIn(NO_VALID_REPLICA, keys_per_host) + + prepared_stmt = self.session.prepare(stmt) + for key in keys: + routing_key = prepared_stmt.bind(key).routing_key + hosts = self.cluster.metadata.get_replicas(self.ks_name, routing_key) + self.assertEqual(1, len(hosts)) # RF is 1 for this keyspace + self.assertIn(key, keys_per_host[hosts[0]]) + + +class VirtualKeypaceTest(BasicSharedKeyspaceUnitTestCase): + virtual_ks_names = ('system_virtual_schema', 'system_views') + + def test_existing_keyspaces_have_correct_virtual_tags(self): + for name, ks in self.cluster.metadata.keyspaces.items(): + if name in self.virtual_ks_names: + self.assertTrue( + ks.virtual, + 'incorrect .virtual value for {}'.format(name) + ) + else: + self.assertFalse( + ks.virtual, + 'incorrect .virtual value for {}'.format(name) + ) + + @greaterthanorequalcass40 + @greaterthanorequaldse67 + def test_expected_keyspaces_exist_and_are_virtual(self): + for name in self.virtual_ks_names: + self.assertTrue( + self.cluster.metadata.keyspaces[name].virtual, + 'incorrect .virtual value for {}'.format(name) + ) + + @greaterthanorequalcass40 + @greaterthanorequaldse67 + def test_virtual_keyspaces_have_expected_schema_structure(self): + self.maxDiff = None + + ingested_virtual_ks_structure = defaultdict(dict) + for ks_name, ks in self.cluster.metadata.keyspaces.items(): + if not ks.virtual: + continue + for tab_name, tab in ks.tables.items(): + ingested_virtual_ks_structure[ks_name][tab_name] = set( + tab.columns.keys() + ) + + # Identify a couple known values to verify we parsed the structure correctly + self.assertIn('table_name', ingested_virtual_ks_structure['system_virtual_schema']['tables']) + self.assertIn('type', ingested_virtual_ks_structure['system_virtual_schema']['columns']) + self.assertIn('total', ingested_virtual_ks_structure['system_views']['sstable_tasks']) diff --git a/tests/integration/standard/test_metrics.py b/tests/integration/standard/test_metrics.py index 
48627b3f01..c33ea26573 100644 --- a/tests/integration/standard/test_metrics.py +++ b/tests/integration/standard/test_metrics.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -14,27 +16,42 @@ import time -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +from cassandra.connection import ConnectionShutdown +from cassandra.policies import HostFilterPolicy, RoundRobinPolicy, FallthroughRetryPolicy + +import unittest from cassandra.query import SimpleStatement from cassandra import ConsistencyLevel, WriteTimeout, Unavailable, ReadTimeout +from cassandra.protocol import SyntaxException + +from cassandra.cluster import NoHostAvailable, ExecutionProfile, EXEC_PROFILE_DEFAULT +from tests.integration import get_cluster, get_node, use_singledc, execute_until_pass, TestCluster +from greplin import scales +from tests.integration import BasicSharedKeyspaceUnitTestCaseRF3WM, BasicExistingKeyspaceUnitTestCase, local -from cassandra.cluster import Cluster, NoHostAvailable -from tests.integration import get_cluster, get_node, use_singledc, PROTOCOL_VERSION, execute_until_pass +import pprint as pp def setup_module(): use_singledc() - +@local class MetricsTests(unittest.TestCase): def setUp(self): - self.cluster = Cluster(metrics_enabled=True, protocol_version=PROTOCOL_VERSION) - self.session = self.cluster.connect("test3rf") + contact_point = ['127.0.0.2'] + self.cluster = TestCluster(contact_points=contact_point, metrics_enabled=True, + execution_profiles= + {EXEC_PROFILE_DEFAULT: + ExecutionProfile( + load_balancing_policy=HostFilterPolicy( + RoundRobinPolicy(), lambda host: host.address in contact_point), + retry_policy=FallthroughRetryPolicy() + ) + } + ) + self.session = self.cluster.connect("test3rf", wait_for_all_pools=True) def tearDown(self): self.cluster.shutdown() @@ -44,8 +61,6 @@ def test_connection_error(self): Trigger and ensure connection_errors are counted Stop all node with the driver knowing about the "DOWN" states. 
""" - - # Test writes for i in range(0, 100): self.session.execute_async("INSERT INTO test (k, v) VALUES ({0}, {1})".format(i, i)) @@ -56,7 +71,8 @@ def test_connection_error(self): try: # Ensure the nodes are actually down query = SimpleStatement("SELECT * FROM test", consistency_level=ConsistencyLevel.ALL) - with self.assertRaises(NoHostAvailable): + # both exceptions can happen depending on when the connection has been detected as defunct + with self.assertRaises((NoHostAvailable, ConnectionShutdown)): self.session.execute(query) finally: get_cluster().start(wait_for_binary_proto=True, wait_other_notice=True) @@ -138,20 +154,22 @@ def test_unavailable(self): self.assertTrue(results) # Stop node gracefully + # Sometimes this commands continues with the other nodes having not noticed + # 1 is down, and a Timeout error is returned instead of an Unavailable get_node(1).stop(wait=True, wait_other_notice=True) - + time.sleep(5) try: # Test write query = SimpleStatement("INSERT INTO test (k, v) VALUES (2, 2)", consistency_level=ConsistencyLevel.ALL) with self.assertRaises(Unavailable): self.session.execute(query) - self.assertEqual(1, self.cluster.metrics.stats.unavailables) + self.assertEqual(self.cluster.metrics.stats.unavailables, 1) # Test write query = SimpleStatement("SELECT * FROM test", consistency_level=ConsistencyLevel.ALL) with self.assertRaises(Unavailable): self.session.execute(query, timeout=None) - self.assertEqual(2, self.cluster.metrics.stats.unavailables) + self.assertEqual(self.cluster.metrics.stats.unavailables, 2) finally: get_node(1).start(wait_other_notice=True, wait_for_binary_proto=True) # Give some time for the cluster to come back up, for the next test @@ -170,3 +188,218 @@ def test_unavailable(self): # def test_retry(self): # # TODO: Look for ways to generate retries # pass + + +class MetricsNamespaceTest(BasicSharedKeyspaceUnitTestCaseRF3WM): + @local + def test_metrics_per_cluster(self): + """ + Test to validate that metrics can be scopped to invdividual clusters + @since 3.6.0 + @jira_ticket PYTHON-561 + @expected_result metrics should be scopped to a cluster level + + @test_category metrics + """ + + cluster2 = TestCluster( + metrics_enabled=True, + execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(retry_policy=FallthroughRetryPolicy())} + ) + cluster2.connect(self.ks_name, wait_for_all_pools=True) + + self.assertEqual(len(cluster2.metadata.all_hosts()), 3) + + query = SimpleStatement("SELECT * FROM {0}.{0}".format(self.ks_name), consistency_level=ConsistencyLevel.ALL) + self.session.execute(query) + + # Pause node so it shows as unreachable to coordinator + get_node(1).pause() + + try: + # Test write + query = SimpleStatement("INSERT INTO {0}.{0} (k, v) VALUES (2, 2)".format(self.ks_name), consistency_level=ConsistencyLevel.ALL) + with self.assertRaises(WriteTimeout): + self.session.execute(query, timeout=None) + finally: + get_node(1).resume() + + # Change the scales stats_name of the cluster2 + cluster2.metrics.set_stats_name('cluster2-metrics') + + stats_cluster1 = self.cluster.metrics.get_stats() + stats_cluster2 = cluster2.metrics.get_stats() + + # Test direct access to stats + self.assertEqual(1, self.cluster.metrics.stats.write_timeouts) + self.assertEqual(0, cluster2.metrics.stats.write_timeouts) + + # Test direct access to a child stats + self.assertNotEqual(0.0, self.cluster.metrics.request_timer['mean']) + self.assertEqual(0.0, cluster2.metrics.request_timer['mean']) + + # Test access via metrics.get_stats() + self.assertNotEqual(0.0, 
stats_cluster1['request_timer']['mean']) + self.assertEqual(0.0, stats_cluster2['request_timer']['mean']) + + # Test access by stats_name + self.assertEqual(0.0, scales.getStats()['cluster2-metrics']['request_timer']['mean']) + + cluster2.shutdown() + + def test_duplicate_metrics_per_cluster(self): + """ + Test to validate that cluster metrics names can't overlap. + @since 3.6.0 + @jira_ticket PYTHON-561 + @expected_result metric names should not be allowed to be same. + + @test_category metrics + """ + cluster2 = TestCluster( + metrics_enabled=True, + monitor_reporting_enabled=False, + execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(retry_policy=FallthroughRetryPolicy())} + ) + + cluster3 = TestCluster( + metrics_enabled=True, + monitor_reporting_enabled=False, + execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(retry_policy=FallthroughRetryPolicy())} + ) + + # Ensure duplicate metric names are not allowed + cluster2.metrics.set_stats_name("appcluster") + cluster2.metrics.set_stats_name("appcluster") + with self.assertRaises(ValueError): + cluster3.metrics.set_stats_name("appcluster") + cluster3.metrics.set_stats_name("devops") + + session2 = cluster2.connect(self.ks_name, wait_for_all_pools=True) + session3 = cluster3.connect(self.ks_name, wait_for_all_pools=True) + + # Basic validation that naming metrics doesn't impact their segregation or accuracy + for i in range(10): + query = SimpleStatement("SELECT * FROM {0}.{0}".format(self.ks_name), consistency_level=ConsistencyLevel.ALL) + session2.execute(query) + + for i in range(5): + query = SimpleStatement("SELECT * FROM {0}.{0}".format(self.ks_name), consistency_level=ConsistencyLevel.ALL) + session3.execute(query) + + self.assertEqual(cluster2.metrics.get_stats()['request_timer']['count'], 10) + self.assertEqual(cluster3.metrics.get_stats()['request_timer']['count'], 5) + + # Check scales to ensure they are appropriately named + self.assertTrue("appcluster" in scales._Stats.stats.keys()) + self.assertTrue("devops" in scales._Stats.stats.keys()) + + cluster2.shutdown() + cluster3.shutdown() + + +class RequestAnalyzer(object): + """ + Class used to track request and error counts for a Session. + Also computes statistics on encoded request size. + """ + + requests = scales.PmfStat('request size') + errors = scales.IntStat('errors') + successful = scales.IntStat("success") + # Throw exceptions when invoked. + throw_on_success = False + throw_on_fail = False + + def __init__(self, session, throw_on_success=False, throw_on_fail=False): + scales.init(self, '/request') + # each instance will be registered with a session, and receive a callback for each request generated + session.add_request_init_listener(self.on_request) + self.throw_on_fail = throw_on_fail + self.throw_on_success = throw_on_success + + def on_request(self, rf): + # This callback is invoked each time a request is created, on the thread creating the request. 
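+ # rf is the ResponseFuture being created for the request; attaching callbacks to it lets this listener observe each request's outcome and encoded size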
+ # We can use this to count events, or add callbacks + rf.add_callbacks(self.on_success, self.on_error, callback_args=(rf,), errback_args=(rf,)) + + def on_success(self, _, response_future): + # future callback on a successful request; just record the size + self.requests.addValue(response_future.request_encoded_size) + self.successful += 1 + if self.throw_on_success: + raise AttributeError + + def on_error(self, _, response_future): + # future callback for failed; record size and increment errors + self.requests.addValue(response_future.request_encoded_size) + self.errors += 1 + if self.throw_on_fail: + raise AttributeError + + def remove_ra(self, session): + session.remove_request_init_listener(self.on_request) + + def __str__(self): + # just extracting request count from the size stats (which are recorded on all requests) + request_sizes = dict(self.requests) + count = request_sizes.pop('count') + return "%d requests (%d errors)\nRequest size statistics:\n%s" % (count, self.errors, pp.pformat(request_sizes)) + + +class MetricsRequestSize(BasicExistingKeyspaceUnitTestCase): + + @classmethod + def setUpClass(cls): + cls.common_setup(1, keyspace_creation=False, monitor_reporting_enabled=False) + + def wait_for_count(self, ra, expected_count, error=False): + for _ in range(10): + if not error: + if ra.successful is expected_count: + return True + else: + if ra.errors is expected_count: + return True + time.sleep(.01) + return False + + def test_metrics_per_cluster(self): + """ + Test to validate that requests listeners. + + This test creates a simple metrics based request listener to track request size, it then + check to ensure that on_success and on_error methods are invoked appropriately. + @since 3.7.0 + @jira_ticket PYTHON-284 + @expected_result in_error, and on_success should be invoked appropriately + + @test_category metrics + """ + + ra = RequestAnalyzer(self.session) + for _ in range(10): + self.session.execute("SELECT release_version FROM system.local") + + for _ in range(3): + try: + self.session.execute("nonsense") + except SyntaxException: + continue + + self.assertTrue(self.wait_for_count(ra, 10)) + self.assertTrue(self.wait_for_count(ra, 3, error=True)) + + ra.remove_ra(self.session) + + # Make sure a poorly coded RA doesn't cause issues + ra = RequestAnalyzer(self.session, throw_on_success=False, throw_on_fail=True) + self.session.execute("SELECT release_version FROM system.local") + + ra.remove_ra(self.session) + + RequestAnalyzer(self.session, throw_on_success=True) + try: + self.session.execute("nonsense") + except SyntaxException: + pass diff --git a/tests/integration/standard/test_policies.py b/tests/integration/standard/test_policies.py new file mode 100644 index 0000000000..bb69243212 --- /dev/null +++ b/tests/integration/standard/test_policies.py @@ -0,0 +1,92 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from cassandra.cluster import ExecutionProfile, EXEC_PROFILE_DEFAULT +from cassandra.policies import HostFilterPolicy, RoundRobinPolicy, SimpleConvictionPolicy, \ + WhiteListRoundRobinPolicy +from cassandra.pool import Host +from cassandra.connection import DefaultEndPoint + +from tests.integration import local, use_singledc, TestCluster + +from concurrent.futures import wait as wait_futures + +def setup_module(): + use_singledc() + +class HostFilterPolicyTests(unittest.TestCase): + + def test_predicate_changes(self): + """ + Test to validate host filter reacts correctly when the predicate return + a different subset of the hosts + HostFilterPolicy + @since 3.8 + @jira_ticket PYTHON-961 + @expected_result the excluded hosts are ignored + + @test_category policy + """ + external_event = True + contact_point = DefaultEndPoint("127.0.0.1") + + single_host = {Host(contact_point, SimpleConvictionPolicy)} + all_hosts = {Host(DefaultEndPoint("127.0.0.{}".format(i)), SimpleConvictionPolicy) for i in (1, 2, 3)} + + predicate = lambda host: host.endpoint == contact_point if external_event else True + hfp = ExecutionProfile( + load_balancing_policy=HostFilterPolicy(RoundRobinPolicy(), predicate=predicate) + ) + cluster = TestCluster(contact_points=(contact_point,), execution_profiles={EXEC_PROFILE_DEFAULT: hfp}, + topology_event_refresh_window=0, + status_event_refresh_window=0) + session = cluster.connect(wait_for_all_pools=True) + + queried_hosts = set() + for _ in range(10): + response = session.execute("SELECT * from system.local") + queried_hosts.update(response.response_future.attempted_hosts) + + self.assertEqual(queried_hosts, single_host) + + external_event = False + futures = session.update_created_pools() + wait_futures(futures, timeout=cluster.connect_timeout) + + queried_hosts = set() + for _ in range(10): + response = session.execute("SELECT * from system.local") + queried_hosts.update(response.response_future.attempted_hosts) + self.assertEqual(queried_hosts, all_hosts) + + +class WhiteListRoundRobinPolicyTests(unittest.TestCase): + + @local + def test_only_connects_to_subset(self): + only_connect_hosts = {"127.0.0.1", "127.0.0.2"} + white_list = ExecutionProfile(load_balancing_policy=WhiteListRoundRobinPolicy(only_connect_hosts)) + cluster = TestCluster(execution_profiles={"white_list": white_list}) + #cluster = Cluster(load_balancing_policy=WhiteListRoundRobinPolicy(only_connect_hosts)) + session = cluster.connect(wait_for_all_pools=True) + queried_hosts = set() + for _ in range(10): + response = session.execute('SELECT * from system.local', execution_profile="white_list") + queried_hosts.update(response.response_future.attempted_hosts) + queried_hosts = set(host.address for host in queried_hosts) + self.assertEqual(queried_hosts, only_connect_hosts) diff --git a/tests/integration/standard/test_prepared_statements.py b/tests/integration/standard/test_prepared_statements.py index 63f62f238e..429aa0efc7 100644 --- a/tests/integration/standard/test_prepared_statements.py +++ b/tests/integration/standard/test_prepared_statements.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -13,18 +15,23 @@ # limitations under the License. -from tests.integration import use_singledc, PROTOCOL_VERSION +from tests.integration import use_singledc, PROTOCOL_VERSION, TestCluster, CASSANDRA_VERSION + +import unittest + +from packaging.version import Version -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa -from cassandra import InvalidRequest +from cassandra import InvalidRequest, DriverException -from cassandra import ConsistencyLevel -from cassandra.cluster import Cluster +from cassandra import ConsistencyLevel, ProtocolVersion from cassandra.query import PreparedStatement, UNSET_VALUE -from tests.integration import get_server_versions +from tests.integration import (get_server_versions, greaterthanorequalcass40, greaterthanorequaldse50, + requirecassandra, BasicSharedKeyspaceUnitTestCase) + +import logging + + +LOG = logging.getLogger(__name__) def setup_module(): @@ -38,7 +45,7 @@ def setUpClass(cls): cls.cass_version = get_server_versions() def setUp(self): - self.cluster = Cluster(metrics_enabled=True, protocol_version=PROTOCOL_VERSION) + self.cluster = TestCluster(metrics_enabled=True, allow_beta_protocol_version=True) self.session = self.cluster.connect() def tearDown(self): @@ -48,7 +55,11 @@ def test_basic(self): """ Test basic PreparedStatement usage """ - + self.session.execute( + """ + DROP KEYSPACE IF EXISTS preparedtests + """ + ) self.session.execute( """ CREATE KEYSPACE preparedtests @@ -232,7 +243,7 @@ def test_unset_values(self): @since 2.6.0 @jira_ticket PYTHON-317 - @expected_result UNSET_VALUE is implicitly added to bind parameters, and properly encoded, leving unset values unaffected. + @expected_result UNSET_VALUE is implicitly added to bind parameters, and properly encoded, leaving unset values unaffected. 
@test_category prepared_statements:binding """ @@ -385,3 +396,241 @@ def test_raise_error_on_prepared_statement_execution_dropped_table(self): with self.assertRaises(InvalidRequest): self.session.execute(prepared, [0]) + + @unittest.skipIf((CASSANDRA_VERSION >= Version('3.11.12') and CASSANDRA_VERSION < Version('4.0')) or \ + CASSANDRA_VERSION >= Version('4.0.2'), + "Fixed server-side in Cassandra 3.11.12, 4.0.2") + def test_fail_if_different_query_id_on_reprepare(self): + """ PYTHON-1124 and CASSANDRA-15252 """ + keyspace = "test_fail_if_different_query_id_on_reprepare" + self.session.execute( + "CREATE KEYSPACE IF NOT EXISTS {} WITH replication = " + "{{'class': 'SimpleStrategy', 'replication_factor': 1}}".format(keyspace) + ) + self.session.execute("CREATE TABLE IF NOT EXISTS {}.foo(k int PRIMARY KEY)".format(keyspace)) + prepared = self.session.prepare("SELECT * FROM {}.foo WHERE k=?".format(keyspace)) + self.session.execute("DROP TABLE {}.foo".format(keyspace)) + self.session.execute("CREATE TABLE {}.foo(k int PRIMARY KEY)".format(keyspace)) + self.session.execute("USE {}".format(keyspace)) + with self.assertRaises(DriverException) as e: + self.session.execute(prepared, [0]) + self.assertIn("ID mismatch", str(e.exception)) + + +@greaterthanorequalcass40 +class PreparedStatementInvalidationTest(BasicSharedKeyspaceUnitTestCase): + + def setUp(self): + self.table_name = "{}.prepared_statement_invalidation_test".format(self.keyspace_name) + self.session.execute("CREATE TABLE {} (a int PRIMARY KEY, b int, d int);".format(self.table_name)) + self.session.execute("INSERT INTO {} (a, b, d) VALUES (1, 1, 1);".format(self.table_name)) + self.session.execute("INSERT INTO {} (a, b, d) VALUES (2, 2, 2);".format(self.table_name)) + self.session.execute("INSERT INTO {} (a, b, d) VALUES (3, 3, 3);".format(self.table_name)) + self.session.execute("INSERT INTO {} (a, b, d) VALUES (4, 4, 4);".format(self.table_name)) + + def tearDown(self): + self.session.execute("DROP TABLE {}".format(self.table_name)) + + def test_invalidated_result_metadata(self): + """ + Tests to make sure cached metadata is updated when an invalidated prepared statement is reprepared. + + @since 2.7.0 + @jira_ticket PYTHON-621 + + Prior to this fix, the request would blow up with a protocol error when the result was decoded expecting a different + number of columns. + """ + wildcard_prepared = self.session.prepare("SELECT * FROM {}".format(self.table_name)) + original_result_metadata = wildcard_prepared.result_metadata + self.assertEqual(len(original_result_metadata), 3) + + r = self.session.execute(wildcard_prepared) + self.assertEqual(r[0], (1, 1, 1)) + + self.session.execute("ALTER TABLE {} DROP d".format(self.table_name)) + + # Get a bunch of requests in the pipeline with varying states of result_meta, reprepare, resolved + futures = set(self.session.execute_async(wildcard_prepared.bind(None)) for _ in range(200)) + for f in futures: + self.assertEqual(f.result()[0], (1, 1)) + + self.assertIsNot(wildcard_prepared.result_metadata, original_result_metadata) + + def test_prepared_id_is_update(self): + """ + Tests that checks the query id from the prepared statement + is updated properly if the table that the prepared statement is querying + is altered. 
+ + @since 3.12 + @jira_ticket PYTHON-808 + + The query id from the prepared statement must have changed + """ + prepared_statement = self.session.prepare("SELECT * from {} WHERE a = ?".format(self.table_name)) + id_before = prepared_statement.result_metadata_id + self.assertEqual(len(prepared_statement.result_metadata), 3) + + self.session.execute("ALTER TABLE {} ADD c int".format(self.table_name)) + bound_statement = prepared_statement.bind((1, )) + self.session.execute(bound_statement, timeout=1) + + id_after = prepared_statement.result_metadata_id + + self.assertNotEqual(id_before, id_after) + self.assertEqual(len(prepared_statement.result_metadata), 4) + + def test_prepared_id_is_updated_across_pages(self): + """ + Test that checks that the query id from the prepared statement + is updated if the table hat the prepared statement is querying + is altered while fetching pages in a single query. + Then it checks that the updated rows have the expected result. + + @since 3.12 + @jira_ticket PYTHON-808 + """ + prepared_statement = self.session.prepare("SELECT * from {}".format(self.table_name)) + id_before = prepared_statement.result_metadata_id + self.assertEqual(len(prepared_statement.result_metadata), 3) + + prepared_statement.fetch_size = 2 + result = self.session.execute(prepared_statement.bind((None))) + + self.assertTrue(result.has_more_pages) + + self.session.execute("ALTER TABLE {} ADD c int".format(self.table_name)) + + result_set = set(x for x in ((1, 1, 1), (2, 2, 2), (3, 3, None, 3), (4, 4, None, 4))) + expected_result_set = set(row for row in result) + + id_after = prepared_statement.result_metadata_id + + self.assertEqual(result_set, expected_result_set) + self.assertNotEqual(id_before, id_after) + self.assertEqual(len(prepared_statement.result_metadata), 4) + + def test_prepare_id_is_updated_across_session(self): + """ + Test that checks that the query id from the prepared statement + is updated if the table hat the prepared statement is querying + is altered by a different session + + @since 3.12 + @jira_ticket PYTHON-808 + """ + one_cluster = TestCluster(metrics_enabled=True) + one_session = one_cluster.connect() + self.addCleanup(one_cluster.shutdown) + + stm = "SELECT * from {} WHERE a = ?".format(self.table_name) + one_prepared_stm = one_session.prepare(stm) + self.assertEqual(len(one_prepared_stm.result_metadata), 3) + + one_id_before = one_prepared_stm.result_metadata_id + + self.session.execute("ALTER TABLE {} ADD c int".format(self.table_name)) + one_session.execute(one_prepared_stm, (1, )) + + one_id_after = one_prepared_stm.result_metadata_id + self.assertNotEqual(one_id_before, one_id_after) + self.assertEqual(len(one_prepared_stm.result_metadata), 4) + + def test_not_reprepare_invalid_statements(self): + """ + Test that checks that an InvalidRequest is arisen if a column + expected by the prepared statement is dropped. 
+ + @since 3.12 + @jira_ticket PYTHON-808 + """ + prepared_statement = self.session.prepare( + "SELECT a, b, d FROM {} WHERE a = ?".format(self.table_name)) + self.session.execute("ALTER TABLE {} DROP d".format(self.table_name)) + with self.assertRaises(InvalidRequest): + self.session.execute(prepared_statement.bind((1, ))) + + def test_id_is_not_updated_conditional_v4(self): + """ + Test that verifies that the result_metadata and the + result_metadata_id are updated correctly in conditional statements + in protocol V4 + + @since 3.13 + @jira_ticket PYTHON-847 + """ + cluster = TestCluster(protocol_version=ProtocolVersion.V4) + session = cluster.connect() + self.addCleanup(cluster.shutdown) + self._test_updated_conditional(session, 9) + + @requirecassandra + def test_id_is_not_updated_conditional_v5(self): + """ + Test that verifies that the result_metadata and the + result_metadata_id are updated correctly in conditional statements + in protocol V5 + @since 3.13 + @jira_ticket PYTHON-847 + """ + cluster = TestCluster(protocol_version=ProtocolVersion.V5) + session = cluster.connect() + self.addCleanup(cluster.shutdown) + self._test_updated_conditional(session, 10) + + @greaterthanorequaldse50 + def test_id_is_not_updated_conditional_dsev1(self): + """ + Test that verifies that the result_metadata and the + result_metadata_id are updated correctly in conditional statements + in protocol DSE V1 + + @since 3.13 + @jira_ticket PYTHON-847 + """ + cluster = TestCluster(protocol_version=ProtocolVersion.DSE_V1) + session = cluster.connect() + self.addCleanup(cluster.shutdown) + self._test_updated_conditional(session, 10) + + @greaterthanorequaldse50 + def test_id_is_not_updated_conditional_dsev2(self): + """ + Test that verifies that the result_metadata and the + result_metadata_id are updated correctly in conditional statements + in protocol DSE V2 + + @since 3.13 + @jira_ticket PYTHON-847 + """ + cluster = TestCluster(protocol_version=ProtocolVersion.DSE_V2) + session = cluster.connect() + self.addCleanup(cluster.shutdown) + self._test_updated_conditional(session, 10) + + def _test_updated_conditional(self, session, value): + prepared_statement = session.prepare( + "INSERT INTO {}(a, b, d) VALUES " + "(?, ? , ?) IF NOT EXISTS".format(self.table_name)) + first_id = prepared_statement.result_metadata_id + LOG.debug('initial result_metadata_id: {}'.format(first_id)) + + def check_result_and_metadata(expected): + self.assertEqual( + session.execute(prepared_statement, (value, value, value))[0], + expected + ) + self.assertEqual(prepared_statement.result_metadata_id, first_id) + self.assertIsNone(prepared_statement.result_metadata) + + # Successful conditional update + check_result_and_metadata((True,)) + + # Failed conditional update + check_result_and_metadata((False, value, value, value)) + + session.execute("ALTER TABLE {} ADD c int".format(self.table_name)) + + # Failed conditional update + check_result_and_metadata((False, value, value, None, value)) diff --git a/tests/integration/standard/test_query.py b/tests/integration/standard/test_query.py index 51ab1f514d..3ede0ac326 100644 --- a/tests/integration/standard/test_query.py +++ b/tests/integration/standard/test_query.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -13,29 +15,44 @@ # limitations under the License. import os from cassandra.concurrent import execute_concurrent +from cassandra import DriverException - -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa - -from cassandra import ConsistencyLevel +import unittest +import logging +from cassandra import ProtocolVersion +from cassandra import ConsistencyLevel, Unavailable, InvalidRequest, cluster from cassandra.query import (PreparedStatement, BoundStatement, SimpleStatement, BatchStatement, BatchType, dict_factory, TraceUnavailable) -from cassandra.cluster import Cluster -from cassandra.policies import HostDistance - -from tests.integration import use_singledc, PROTOCOL_VERSION, BasicSharedKeyspaceUnitTestCase, get_server_versions, greaterthanprotocolv3 +from cassandra.cluster import NoHostAvailable, ExecutionProfile, EXEC_PROFILE_DEFAULT, Cluster +from cassandra.policies import HostDistance, RoundRobinPolicy, WhiteListRoundRobinPolicy +from tests.integration import use_singledc, PROTOCOL_VERSION, BasicSharedKeyspaceUnitTestCase, \ + greaterthanprotocolv3, MockLoggingHandler, get_supported_protocol_versions, local, get_cluster, setup_keyspace, \ + USE_CASS_EXTERNAL, greaterthanorequalcass40, DSE_VERSION, TestCluster, requirecassandra +from tests import notwindows +from tests.integration import greaterthanorequalcass30, get_node import time +import random import re +from unittest import mock + + +log = logging.getLogger(__name__) + def setup_module(): - use_singledc() - global CASS_SERVER_VERSION - CASS_SERVER_VERSION = get_server_versions()[0] + if not USE_CASS_EXTERNAL: + use_singledc(start=False) + ccm_cluster = get_cluster() + ccm_cluster.stop() + # This is necessary because test_too_many_statements may + # timeout otherwise + config_options = {'write_request_timeout_in_ms': '20000'} + ccm_cluster.set_configuration_options(config_options) + ccm_cluster.start(wait_for_binary_proto=True, wait_other_notice=True) + + setup_keyspace() class QueryTests(BasicSharedKeyspaceUnitTestCase): @@ -70,6 +87,22 @@ def test_trace_prints_okay(self): for event in trace.events: str(event) + def test_row_error_message(self): + """ + Test to validate, new column deserialization message + @since 3.7.0 + @jira_ticket PYTHON-361 + @expected_result Special failed decoding message should be present + + @test_category tracing + """ + self.session.execute("CREATE TABLE {0}.{1} (k int PRIMARY KEY, v timestamp)".format(self.keyspace_name,self.function_table_name)) + ss = SimpleStatement("INSERT INTO {0}.{1} (k, v) VALUES (1, 1000000000000000)".format(self.keyspace_name, self.function_table_name)) + self.session.execute(ss) + with self.assertRaises(DriverException) as context: + self.session.execute("SELECT * FROM {0}.{1}".format(self.keyspace_name, self.function_table_name)) + self.assertIn("Failed decoding result column", str(context.exception)) + def 
test_trace_id_to_resultset(self): future = self.session.execute_async("SELECT * FROM system.local", trace=True) @@ -87,19 +120,22 @@ def test_trace_id_to_resultset(self): self.assertListEqual([rs_trace], rs.get_all_query_traces()) def test_trace_ignores_row_factory(self): - self.session.row_factory = dict_factory - - query = "SELECT * FROM system.local" - statement = SimpleStatement(query) - rs = self.session.execute(statement, trace=True) - - # Ensure this does not throw an exception - trace = rs.get_query_trace() - self.assertTrue(trace.events) - str(trace) - for event in trace.events: - str(event) - + with TestCluster( + execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=dict_factory)} + ) as cluster: + s = cluster.connect() + query = "SELECT * FROM system.local" + statement = SimpleStatement(query) + rs = s.execute(statement, trace=True) + + # Ensure this does not throw an exception + trace = rs.get_query_trace() + self.assertTrue(trace.events) + str(trace) + for event in trace.events: + str(event) + + @local @greaterthanprotocolv3 def test_client_ip_in_trace(self): """ @@ -109,7 +145,7 @@ def test_client_ip_in_trace(self): only be the case if the c* version is 2.2 or greater @since 2.6.0 - @jira_ticket PYTHON-235 + @jira_ticket PYTHON-435 @expected_result client address should be present in C* >= 2.2, otherwise should be none. @test_category tracing @@ -128,7 +164,7 @@ def test_client_ip_in_trace(self): response_future.result() # Fetch the client_ip from the trace. - trace = response_future.get_query_trace(max_wait=2.0) + trace = response_future.get_query_trace(max_wait=10.0) client_ip = trace.client # Ip address should be in the local_host range @@ -138,6 +174,33 @@ def test_client_ip_in_trace(self): self.assertIsNotNone(client_ip, "Client IP was not set in trace with C* >= 2.2") self.assertTrue(pat.match(client_ip), "Client IP from trace did not match the expected value") + def test_trace_cl(self): + """ + Test to ensure that CL is set correctly honored when executing trace queries. + + @since 3.3 + @jira_ticket PYTHON-435 + @expected_result Consistency Levels set on get_query_trace should be honored + """ + # Execute a query + query = "SELECT * FROM system.local" + statement = SimpleStatement(query) + response_future = self.session.execute_async(statement, trace=True) + response_future.result() + with self.assertRaises(Unavailable): + response_future.get_query_trace(query_cl=ConsistencyLevel.THREE) + # Try again with a smattering of other CL's + self.assertIsNotNone(response_future.get_query_trace(max_wait=2.0, query_cl=ConsistencyLevel.TWO).trace_id) + response_future = self.session.execute_async(statement, trace=True) + response_future.result() + self.assertIsNotNone(response_future.get_query_trace(max_wait=2.0, query_cl=ConsistencyLevel.ONE).trace_id) + response_future = self.session.execute_async(statement, trace=True) + response_future.result() + with self.assertRaises(InvalidRequest): + self.assertIsNotNone(response_future.get_query_trace(max_wait=2.0, query_cl=ConsistencyLevel.ANY).trace_id) + self.assertIsNotNone(response_future.get_query_trace(max_wait=2.0, query_cl=ConsistencyLevel.QUORUM).trace_id) + + @notwindows def test_incomplete_query_trace(self): """ Tests to ensure that partial tracing works. 
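For reference, a minimal, self-contained sketch of the tracing API exercised by the tests above; the contact point is an illustrative assumption, while trace=True, get_query_trace(max_wait=..., query_cl=...), trace.client and trace.events are the driver calls the tests themselves use:

    from cassandra import ConsistencyLevel
    from cassandra.cluster import Cluster
    from cassandra.query import SimpleStatement

    cluster = Cluster(["127.0.0.1"])  # illustrative contact point
    session = cluster.connect()

    # trace=True asks the coordinator to record a trace for this request
    future = session.execute_async(SimpleStatement("SELECT * FROM system.local"), trace=True)
    future.result()  # wait for the query itself to finish

    # the trace is fetched lazily from system_traces; max_wait bounds the polling and
    # query_cl sets the consistency level used for those polling queries (PYTHON-435)
    trace = future.get_query_trace(max_wait=10.0, query_cl=ConsistencyLevel.ONE)
    print(trace.trace_id, trace.client)
    for event in trace.events:
        print(event)

    cluster.shutdown()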
@@ -165,7 +228,7 @@ def test_incomplete_query_trace(self): self.assertTrue(self._wait_for_trace_to_populate(trace.trace_id)) # Delete trace duration from the session (this is what the driver polls for "complete") - delete_statement = SimpleStatement("DELETE duration FROM system_traces.sessions WHERE session_id = {}".format(trace.trace_id), consistency_level=ConsistencyLevel.ALL) + delete_statement = SimpleStatement("DELETE duration FROM system_traces.sessions WHERE session_id = {0}".format(trace.trace_id), consistency_level=ConsistencyLevel.ALL) self.session.execute(delete_statement) self.assertTrue(self._wait_for_trace_to_delete(trace.trace_id)) @@ -199,12 +262,40 @@ def _wait_for_trace_to_delete(self, trace_id): return count != retry_max def _is_trace_present(self, trace_id): - select_statement = SimpleStatement("SElECT duration FROM system_traces.sessions WHERE session_id = {}".format(trace_id), consistency_level=ConsistencyLevel.ALL) + select_statement = SimpleStatement("SElECT duration FROM system_traces.sessions WHERE session_id = {0}".format(trace_id), consistency_level=ConsistencyLevel.ALL) ssrs = self.session.execute(select_statement) - if(ssrs[0].duration is None): + if not len(ssrs.current_rows) or ssrs[0].duration is None: return False return True + def test_query_by_id(self): + """ + Test to ensure column_types are set as part of the result set + + @since 3.8 + @jira_ticket PYTHON-648 + @expected_result column_names should be present. + + @test_category queries basic + """ + create_table = "CREATE TABLE {0}.{1} (id int primary key, m map<int,text>)".format(self.keyspace_name, self.function_table_name) + self.session.execute(create_table) + + self.session.execute("insert into "+self.keyspace_name+"."+self.function_table_name+" (id, m) VALUES ( 1, {1: 'one', 2: 'two', 3:'three'})") + results1 = self.session.execute("select id, m from {0}.{1}".format(self.keyspace_name, self.function_table_name)) + + self.assertIsNotNone(results1.column_types) + self.assertEqual(results1.column_types[0].typename, 'int') + self.assertEqual(results1.column_types[1].typename, 'map') + self.assertEqual(results1.column_types[0].cassname, 'Int32Type') + self.assertEqual(results1.column_types[1].cassname, 'MapType') + self.assertEqual(len(results1.column_types[0].subtypes), 0) + self.assertEqual(len(results1.column_types[1].subtypes), 2) + self.assertEqual(results1.column_types[1].subtypes[0].typename, "int") + self.assertEqual(results1.column_types[1].subtypes[1].typename, "varchar") + self.assertEqual(results1.column_types[1].subtypes[0].cassname, "Int32Type") + self.assertEqual(results1.column_types[1].subtypes[1].cassname, "VarcharType") + def test_column_names(self): """ Test to validate the columns are present on the result set. 
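A compact sketch of the result-set metadata inspected in test_query_by_id above; the contact point, keyspace and table names are placeholders, while column_names, column_types, typename, cassname and subtypes are the attributes the test asserts on:

    from cassandra.cluster import Cluster

    cluster = Cluster(["127.0.0.1"])  # placeholder contact point
    session = cluster.connect()

    rows = session.execute("SELECT id, m FROM ks.tbl")  # ks.tbl is a placeholder table
    print(rows.column_names)  # e.g. ['id', 'm']
    for col_type in rows.column_types:
        # each entry is a cqltypes class exposing .typename ('int', 'map', ...),
        # .cassname ('Int32Type', 'MapType', ...) and .subtypes for collection types
        print(col_type.typename, col_type.cassname, [s.typename for s in col_type.subtypes])

    cluster.shutdown()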
@@ -225,15 +316,56 @@ def test_column_names(self): score INT, PRIMARY KEY (user, game, year, month, day) )""".format(self.keyspace_name, self.function_table_name) + + self.session.execute(create_table) result_set = self.session.execute("SELECT * FROM {0}.{1}".format(self.keyspace_name, self.function_table_name)) + self.assertIsNotNone(result_set.column_types) + self.assertEqual(result_set.column_names, [u'user', u'game', u'year', u'month', u'day', u'score']) + @greaterthanorequalcass30 + def test_basic_json_query(self): + insert_query = SimpleStatement("INSERT INTO test3rf.test(k, v) values (1, 1)", consistency_level = ConsistencyLevel.QUORUM) + json_query = SimpleStatement("SELECT JSON * FROM test3rf.test where k=1", consistency_level = ConsistencyLevel.QUORUM) + + self.session.execute(insert_query) + results = self.session.execute(json_query) + self.assertEqual(results.column_names, ["[json]"]) + self.assertEqual(results[0][0], '{"k": 1, "v": 1}') + + def test_host_targeting_query(self): + """ + Test to validate that single host targeting works. + + @since 3.17.0 + @jira_ticket PYTHON-933 + @expected_result the coordinator host is always the one set + """ + + default_ep = self.cluster.profile_manager.default + # copy of default EP with checkable LBP + checkable_ep = self.session.execution_profile_clone_update( + ep=default_ep, + load_balancing_policy=mock.Mock(wraps=default_ep.load_balancing_policy) + ) + query = SimpleStatement("INSERT INTO test3rf.test(k, v) values (1, 1)") + + for i in range(10): + host = random.choice(self.cluster.metadata.all_hosts()) + log.debug('targeting {}'.format(host)) + future = self.session.execute_async(query, host=host, execution_profile=checkable_ep) + future.result() + # check we're using the selected host + self.assertEqual(host, future.coordinator_host) + # check that this bypasses the LBP + self.assertFalse(checkable_ep.load_balancing_policy.make_query_plan.called) + class PreparedStatementTests(unittest.TestCase): def setUp(self): - self.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + self.cluster = TestCluster() self.session = self.cluster.connect() def tearDown(self): @@ -314,6 +446,185 @@ def test_bound_keyspace(self): self.assertEqual(bound.keyspace, 'test3rf') +class ForcedHostIndexPolicy(RoundRobinPolicy): + def __init__(self, host_index_to_use=0): + super(ForcedHostIndexPolicy, self).__init__() + self.host_index_to_use = host_index_to_use + + def set_host(self, host_index): + """ 0-based index of which host to use """ + self.host_index_to_use = host_index + + def make_query_plan(self, working_keyspace=None, query=None): + live_hosts = sorted(list(self._live_hosts)) + host = [] + try: + host = [live_hosts[self.host_index_to_use]] + except IndexError as e: + raise IndexError( + 'You specified an index larger than the number of hosts. Total hosts: {}. Index specified: {}'.format( + len(live_hosts), self.host_index_to_use + )) from e + return host + + +class PreparedStatementMetdataTest(unittest.TestCase): + + def test_prepared_metadata_generation(self): + """ + Test to validate that result metadata is appropriately populated across protocol versions + + In protocol version 1 result metadata is retrieved every time the statement is issued. In all + other protocol versions it's set once upon the prepare, then re-used. This test ensures that it manifests + itself the same across multiple protocol versions. + + @since 3.6.0 + @jira_ticket PYTHON-71 + @expected_result result metadata is consistent. 
+ """ + + base_line = None + for proto_version in get_supported_protocol_versions(): + beta_flag = True if proto_version in ProtocolVersion.BETA_VERSIONS else False + cluster = Cluster(protocol_version=proto_version, allow_beta_protocol_version=beta_flag) + + session = cluster.connect() + select_statement = session.prepare("SELECT * FROM system.local") + if proto_version == 1: + self.assertEqual(select_statement.result_metadata, None) + else: + self.assertNotEqual(select_statement.result_metadata, None) + future = session.execute_async(select_statement) + results = future.result() + if base_line is None: + base_line = results[0]._asdict().keys() + else: + self.assertEqual(base_line, results[0]._asdict().keys()) + cluster.shutdown() + + +class PreparedStatementArgTest(unittest.TestCase): + + def setUp(self): + self.mock_handler = MockLoggingHandler() + logger = logging.getLogger(cluster.__name__) + logger.addHandler(self.mock_handler) + + def test_prepare_on_all_hosts(self): + """ + Test to validate prepare_on_all_hosts flag is honored. + + Force the host of each query to ensure prepared queries are cycled over nodes that should not + have them prepared. Check the logs to insure they are being re-prepared on those nodes + + @since 3.4.0 + @jira_ticket PYTHON-556 + @expected_result queries will have to re-prepared on hosts that aren't the control connection + """ + clus = TestCluster(prepare_on_all_hosts=False, reprepare_on_up=False) + self.addCleanup(clus.shutdown) + + session = clus.connect(wait_for_all_pools=True) + select_statement = session.prepare("SELECT k FROM test3rf.test WHERE k = ?") + for host in clus.metadata.all_hosts(): + session.execute(select_statement, (1, ), host=host) + self.assertEqual(2, self.mock_handler.get_message_count('debug', "Re-preparing")) + + def test_prepare_batch_statement(self): + """ + Test to validate a prepared statement used inside a batch statement is correctly handled + by the driver + + @since 3.10 + @jira_ticket PYTHON-706 + @expected_result queries will have to re-prepared on hosts that aren't the control connection + and the batch statement will be sent. + """ + policy = ForcedHostIndexPolicy() + clus = TestCluster( + execution_profiles={ + EXEC_PROFILE_DEFAULT: ExecutionProfile(load_balancing_policy=policy), + }, + prepare_on_all_hosts=False, + reprepare_on_up=False, + ) + self.addCleanup(clus.shutdown) + + table = "test3rf.%s" % self._testMethodName.lower() + + session = clus.connect(wait_for_all_pools=True) + + session.execute("DROP TABLE IF EXISTS %s" % table) + session.execute("CREATE TABLE %s (k int PRIMARY KEY, v int )" % table) + + insert_statement = session.prepare("INSERT INTO %s (k, v) VALUES (?, ?)" % table) + + # This is going to query a host where the query + # is not prepared + policy.set_host(1) + batch_statement = BatchStatement(consistency_level=ConsistencyLevel.ONE) + batch_statement.add(insert_statement, (1, 2)) + session.execute(batch_statement) + + # To verify our test assumption that queries are getting re-prepared properly + self.assertEqual(1, self.mock_handler.get_message_count('debug', "Re-preparing")) + + select_results = session.execute(SimpleStatement("SELECT * FROM %s WHERE k = 1" % table, + consistency_level=ConsistencyLevel.ALL)) + first_row = select_results[0][:2] + self.assertEqual((1, 2), first_row) + + def test_prepare_batch_statement_after_alter(self): + """ + Test to validate a prepared statement used inside a batch statement is correctly handled + by the driver. 
The metadata might be updated when a table is altered. This tests combines + queries not being prepared and an update of the prepared statement metadata + + @since 3.10 + @jira_ticket PYTHON-706 + @expected_result queries will have to re-prepared on hosts that aren't the control connection + and the batch statement will be sent. + """ + clus = TestCluster(prepare_on_all_hosts=False, reprepare_on_up=False) + self.addCleanup(clus.shutdown) + + table = "test3rf.%s" % self._testMethodName.lower() + + session = clus.connect(wait_for_all_pools=True) + + session.execute("DROP TABLE IF EXISTS %s" % table) + session.execute("CREATE TABLE %s (k int PRIMARY KEY, a int, b int, d int)" % table) + insert_statement = session.prepare("INSERT INTO %s (k, b, d) VALUES (?, ?, ?)" % table) + + # Altering the table might trigger an update in the insert metadata + session.execute("ALTER TABLE %s ADD c int" % table) + + values_to_insert = [(1, 2, 3), (2, 3, 4), (3, 4, 5), (4, 5, 6)] + + # We query the three hosts in order (due to the ForcedHostIndexPolicy) + # the first three queries will have to be repreapred and the rest should + # work as normal batch prepared statements + hosts = clus.metadata.all_hosts() + for i in range(10): + value_to_insert = values_to_insert[i % len(values_to_insert)] + batch_statement = BatchStatement(consistency_level=ConsistencyLevel.ONE) + batch_statement.add(insert_statement, value_to_insert) + session.execute(batch_statement, host=hosts[i % len(hosts)]) + + select_results = session.execute("SELECT * FROM %s" % table) + expected_results = [ + (1, None, 2, None, 3), + (2, None, 3, None, 4), + (3, None, 4, None, 5), + (4, None, 5, None, 6) + ] + + self.assertEqual(set(expected_results), set(select_results._current_rows)) + + # To verify our test assumption that queries are getting re-prepared properly + self.assertEqual(3, self.mock_handler.get_message_count('debug', "Re-preparing")) + + class PrintStatementTests(unittest.TestCase): """ Test that shows the format used when printing Statements @@ -333,7 +644,7 @@ def test_prepared_statement(self): Highlight the difference between Prepared and Bound statements """ - cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cluster = TestCluster() session = cluster.connect() prepared = session.prepare('INSERT INTO test3rf.test (k, v) VALUES (?, ?)') @@ -357,10 +668,10 @@ def setUp(self): "Protocol 2.0+ is required for BATCH operations, currently testing against %r" % (PROTOCOL_VERSION,)) - self.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + self.cluster = TestCluster() if PROTOCOL_VERSION < 3: self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1) - self.session = self.cluster.connect() + self.session = self.cluster.connect(wait_for_all_pools=True) def tearDown(self): self.cluster.shutdown() @@ -445,11 +756,6 @@ def test_no_parameters(self): self.session.execute(batch) self.confirm_results() - def test_no_parameters_many_times(self): - for i in range(1000): - self.test_no_parameters() - self.session.execute("TRUNCATE test3rf.test") - def test_unicode(self): ddl = ''' CREATE TABLE test3rf.testtext ( @@ -465,6 +771,26 @@ def test_unicode(self): finally: self.session.execute("DROP TABLE test3rf.testtext") + def test_too_many_statements(self): + # The actual max # of statements is 0xFFFF, but this can occasionally cause a server write timeout. 
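+ # Use a smaller batch for the successful execute below to stay well under that limit; the full 0xFFFF limit is still exercised afterward, where adding one more statement raises ValueError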
+ large_batch = 0xFFF + max_statements = 0xFFFF + ss = SimpleStatement("INSERT INTO test3rf.test (k, v) VALUES (0, 0)") + b = BatchStatement(batch_type=BatchType.UNLOGGED, consistency_level=ConsistencyLevel.ONE) + + # large number works works + b.add_all([ss] * large_batch, [None] * large_batch) + self.session.execute(b, timeout=30.0) + + b = BatchStatement(batch_type=BatchType.UNLOGGED, consistency_level=ConsistencyLevel.ONE) + # max + 1 raises + b.add_all([ss] * max_statements, [None] * max_statements) + self.assertRaises(ValueError, b.add, ss) + + # also would have bombed trying to encode + b._statements_and_parameters.append((False, ss.query_string, ())) + self.assertRaises(NoHostAvailable, self.session.execute, b) + class SerialConsistencyTests(unittest.TestCase): def setUp(self): @@ -473,7 +799,7 @@ def setUp(self): "Protocol 2.0+ is required for Serial Consistency, currently testing against %r" % (PROTOCOL_VERSION,)) - self.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + self.cluster = TestCluster() if PROTOCOL_VERSION < 3: self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1) self.session = self.cluster.connect() @@ -564,7 +890,8 @@ def setUp(self): "Protocol 2.0+ is required for Lightweight transactions, currently testing against %r" % (PROTOCOL_VERSION,)) - self.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + serial_profile = ExecutionProfile(consistency_level=ConsistencyLevel.SERIAL) + self.cluster = TestCluster(execution_profiles={'serial': serial_profile}) self.session = self.cluster.connect() ddl = ''' @@ -573,11 +900,20 @@ def setUp(self): v int )''' self.session.execute(ddl) + ddl = ''' + CREATE TABLE test3rf.lwt_clustering ( + k int, + c int, + v int, + PRIMARY KEY (k, c))''' + self.session.execute(ddl) + def tearDown(self): """ Shutdown cluster """ self.session.execute("DROP TABLE test3rf.lwt") + self.session.execute("DROP TABLE test3rf.lwt_clustering") self.cluster.shutdown() def test_no_connection_refused_on_timeout(self): @@ -605,24 +941,132 @@ def test_no_connection_refused_on_timeout(self): continue else: # In this case result is an exception - if type(result).__name__ == "NoHostAvailable": + exception_type = type(result).__name__ + if exception_type == "NoHostAvailable": self.fail("PYTHON-91: Disconnected from Cassandra: %s" % result.message) - if type(result).__name__ == "WriteTimeout": - received_timeout = True - continue - if type(result).__name__ == "WriteFailure": - received_timeout = True - continue - if type(result).__name__ == "ReadTimeout": - continue - if type(result).__name__ == "ReadFailure": + if exception_type in ["WriteTimeout", "WriteFailure", "ReadTimeout", "ReadFailure", "ErrorMessageSub", "ErrorMessage"]: + if type(result).__name__ in ["WriteTimeout", "WriteFailure"]: + received_timeout = True continue - self.fail("Unexpected exception %s: %s" % (type(result).__name__, result.message)) + self.fail("Unexpected exception %s: %s" % (exception_type, result.message)) # Make sure test passed self.assertTrue(received_timeout) + def test_was_applied_batch_stmt(self): + """ + Test to ensure `:attr:cassandra.cluster.ResultSet.was_applied` works as expected + with Batchstatements. 
+ + For both type of batches verify was_applied has the correct result + under different scenarios: + - If on LWT fails the rest of the statements fail including normal UPSERTS + - If on LWT fails the rest of the statements fail + - All the queries succeed + + @since 3.14 + @jira_ticket PYTHON-848 + @expected_result `:attr:cassandra.cluster.ResultSet.was_applied` is updated as + expected + + @test_category query + """ + for batch_type in (BatchType.UNLOGGED, BatchType.LOGGED): + batch_statement = BatchStatement(batch_type) + batch_statement.add_all(["INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 0, 10) IF NOT EXISTS;", + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 1, 10);", + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 2, 10);"], [None] * 3) + result = self.session.execute(batch_statement) + #self.assertTrue(result.was_applied) + + # Should fail since (0, 0, 10) have already been written + # The non conditional insert shouldn't be written as well + batch_statement = BatchStatement(batch_type) + batch_statement.add_all(["INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 0, 10) IF NOT EXISTS;", + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 3, 10) IF NOT EXISTS;", + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 4, 10);", + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 5, 10) IF NOT EXISTS;"], [None] * 4) + result = self.session.execute(batch_statement) + self.assertFalse(result.was_applied) + + all_rows = self.session.execute("SELECT * from test3rf.lwt_clustering", execution_profile='serial') + # Verify the non conditional insert hasn't been inserted + self.assertEqual(len(all_rows.current_rows), 3) + + # Should fail since (0, 0, 10) have already been written + batch_statement = BatchStatement(batch_type) + batch_statement.add_all(["INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 0, 10) IF NOT EXISTS;", + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 3, 10) IF NOT EXISTS;", + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 5, 10) IF NOT EXISTS;"], [None] * 3) + result = self.session.execute(batch_statement) + self.assertFalse(result.was_applied) + + # Should fail since (0, 0, 10) have already been written + batch_statement.add("INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 0, 10) IF NOT EXISTS;") + result = self.session.execute(batch_statement) + self.assertFalse(result.was_applied) + + # Should succeed + batch_statement = BatchStatement(batch_type) + batch_statement.add_all(["INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 3, 10) IF NOT EXISTS;", + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 4, 10) IF NOT EXISTS;", + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 5, 10) IF NOT EXISTS;"], [None] * 3) + + result = self.session.execute(batch_statement) + self.assertTrue(result.was_applied) + + all_rows = self.session.execute("SELECT * from test3rf.lwt_clustering", execution_profile='serial') + for i, row in enumerate(all_rows): + self.assertEqual((0, i, 10), (row[0], row[1], row[2])) + + self.session.execute("TRUNCATE TABLE test3rf.lwt_clustering") + + def test_empty_batch_statement(self): + """ + Test to ensure `:attr:cassandra.cluster.ResultSet.was_applied` works as expected + with empty Batchstatements. 
+ + @since 3.14 + @jira_ticket PYTHON-848 + @expected_result an Exception is raised + expected + + @test_category query + """ + batch_statement = BatchStatement() + results = self.session.execute(batch_statement) + with self.assertRaises(RuntimeError): + results.was_applied + + @unittest.skip("Skipping until PYTHON-943 is resolved") + def test_was_applied_batch_string(self): + batch_statement = BatchStatement(BatchType.LOGGED) + batch_statement.add_all(["INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 0, 10);", + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 1, 10);", + "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 2, 10);"], [None] * 3) + self.session.execute(batch_statement) + + batch_str = """ + BEGIN unlogged batch + INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 0, 10) IF NOT EXISTS; + INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 1, 10) IF NOT EXISTS; + INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 2, 10) IF NOT EXISTS; + APPLY batch; + """ + result = self.session.execute(batch_str) + self.assertFalse(result.was_applied) + + batch_str = """ + BEGIN unlogged batch + INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 3, 10) IF NOT EXISTS; + INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 4, 10) IF NOT EXISTS; + INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 5, 10) IF NOT EXISTS; + APPLY batch; + """ + result = self.session.execute(batch_str) + self.assertTrue(result.was_applied) + class BatchStatementDefaultRoutingKeyTests(unittest.TestCase): # Test for PYTHON-126: BatchStatement.add() should set the routing key of the first added prepared statement @@ -632,7 +1076,7 @@ def setUp(self): raise unittest.SkipTest( "Protocol 2.0+ is required for BATCH operations, currently testing against %r" % (PROTOCOL_VERSION,)) - self.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + self.cluster = TestCluster() self.session = self.cluster.connect() query = """ INSERT INTO test3rf.test (k, v) VALUES (?, ?) @@ -709,12 +1153,9 @@ def test_inherit_first_rk_prepared_param(self): self.assertEqual(batch.routing_key, self.prepared.bind((1, 0)).routing_key) +@greaterthanorequalcass30 class MaterializedViewQueryTest(BasicSharedKeyspaceUnitTestCase): - def setUp(self): - if CASS_SERVER_VERSION < (3, 0): - raise unittest.SkipTest("Materialized views require Cassandra 3.0+") - def test_mv_filtering(self): """ Test to ensure that cql filtering where clauses are properly supported in the python driver. 
@@ -744,31 +1185,36 @@ def test_mv_filtering(self): SELECT * FROM {0}.scores WHERE game IS NOT NULL AND score IS NOT NULL AND user IS NOT NULL AND year IS NOT NULL AND month IS NOT NULL AND day IS NOT NULL PRIMARY KEY (game, score, user, year, month, day) - WITH CLUSTERING ORDER BY (score DESC)""".format(self.keyspace_name) + WITH CLUSTERING ORDER BY (score DESC, user ASC, year ASC, month ASC, day ASC)""".format(self.keyspace_name) create_mv_dailyhigh = """CREATE MATERIALIZED VIEW {0}.dailyhigh AS SELECT * FROM {0}.scores WHERE game IS NOT NULL AND year IS NOT NULL AND month IS NOT NULL AND day IS NOT NULL AND score IS NOT NULL AND user IS NOT NULL PRIMARY KEY ((game, year, month, day), score, user) - WITH CLUSTERING ORDER BY (score DESC)""".format(self.keyspace_name) + WITH CLUSTERING ORDER BY (score DESC, user ASC)""".format(self.keyspace_name) create_mv_monthlyhigh = """CREATE MATERIALIZED VIEW {0}.monthlyhigh AS SELECT * FROM {0}.scores WHERE game IS NOT NULL AND year IS NOT NULL AND month IS NOT NULL AND score IS NOT NULL AND user IS NOT NULL AND day IS NOT NULL PRIMARY KEY ((game, year, month), score, user, day) - WITH CLUSTERING ORDER BY (score DESC)""".format(self.keyspace_name) + WITH CLUSTERING ORDER BY (score DESC, user ASC, day ASC)""".format(self.keyspace_name) create_mv_filtereduserhigh = """CREATE MATERIALIZED VIEW {0}.filtereduserhigh AS SELECT * FROM {0}.scores WHERE user in ('jbellis', 'pcmanus') AND game IS NOT NULL AND score IS NOT NULL AND year is NOT NULL AND day is not NULL and month IS NOT NULL PRIMARY KEY (game, score, user, year, month, day) - WITH CLUSTERING ORDER BY (score DESC)""".format(self.keyspace_name) + WITH CLUSTERING ORDER BY (score DESC, user ASC, year ASC, month ASC, day ASC)""".format(self.keyspace_name) self.session.execute(create_mv_alltime) self.session.execute(create_mv_dailyhigh) self.session.execute(create_mv_monthlyhigh) self.session.execute(create_mv_filtereduserhigh) + self.addCleanup(self.session.execute, "DROP MATERIALIZED VIEW {0}.alltimehigh".format(self.keyspace_name)) + self.addCleanup(self.session.execute, "DROP MATERIALIZED VIEW {0}.dailyhigh".format(self.keyspace_name)) + self.addCleanup(self.session.execute, "DROP MATERIALIZED VIEW {0}.monthlyhigh".format(self.keyspace_name)) + self.addCleanup(self.session.execute, "DROP MATERIALIZED VIEW {0}.filtereduserhigh".format(self.keyspace_name)) + prepared_insert = self.session.prepare("""INSERT INTO {0}.scores (user, game, year, month, day, score) VALUES (?, ?, ? ,? ,?, ?)""".format(self.keyspace_name)) bound = prepared_insert.bind(('pcmanus', 'Coup', 2015, 5, 1, 4000)) @@ -800,73 +1246,73 @@ def test_mv_filtering(self): query_statement = SimpleStatement("SELECT * FROM {0}.alltimehigh WHERE game='Coup'".format(self.keyspace_name), consistency_level=ConsistencyLevel.QUORUM) results = self.session.execute(query_statement) - self.assertEquals(results[0].game, 'Coup') - self.assertEquals(results[0].year, 2015) - self.assertEquals(results[0].month, 5) - self.assertEquals(results[0].day, 1) - self.assertEquals(results[0].score, 4000) - self.assertEquals(results[0].user, "pcmanus") + self.assertEqual(results[0].game, 'Coup') + self.assertEqual(results[0].year, 2015) + self.assertEqual(results[0].month, 5) + self.assertEqual(results[0].day, 1) + self.assertEqual(results[0].score, 4000) + self.assertEqual(results[0].user, "pcmanus") # Test prepared statement and daily high filtering prepared_query = self.session.prepare("SELECT * FROM {0}.dailyhigh WHERE game=? AND year=? AND month=? 
and day=?".format(self.keyspace_name)) bound_query = prepared_query.bind(("Coup", 2015, 6, 2)) results = self.session.execute(bound_query) - self.assertEquals(results[0].game, 'Coup') - self.assertEquals(results[0].year, 2015) - self.assertEquals(results[0].month, 6) - self.assertEquals(results[0].day, 2) - self.assertEquals(results[0].score, 2000) - self.assertEquals(results[0].user, "pcmanus") - - self.assertEquals(results[1].game, 'Coup') - self.assertEquals(results[1].year, 2015) - self.assertEquals(results[1].month, 6) - self.assertEquals(results[1].day, 2) - self.assertEquals(results[1].score, 1000) - self.assertEquals(results[1].user, "tjake") + self.assertEqual(results[0].game, 'Coup') + self.assertEqual(results[0].year, 2015) + self.assertEqual(results[0].month, 6) + self.assertEqual(results[0].day, 2) + self.assertEqual(results[0].score, 2000) + self.assertEqual(results[0].user, "pcmanus") + + self.assertEqual(results[1].game, 'Coup') + self.assertEqual(results[1].year, 2015) + self.assertEqual(results[1].month, 6) + self.assertEqual(results[1].day, 2) + self.assertEqual(results[1].score, 1000) + self.assertEqual(results[1].user, "tjake") # Test montly high range queries prepared_query = self.session.prepare("SELECT * FROM {0}.monthlyhigh WHERE game=? AND year=? AND month=? and score >= ? and score <= ?".format(self.keyspace_name)) bound_query = prepared_query.bind(("Coup", 2015, 6, 2500, 3500)) results = self.session.execute(bound_query) - self.assertEquals(results[0].game, 'Coup') - self.assertEquals(results[0].year, 2015) - self.assertEquals(results[0].month, 6) - self.assertEquals(results[0].day, 20) - self.assertEquals(results[0].score, 3500) - self.assertEquals(results[0].user, "jbellis") - - self.assertEquals(results[1].game, 'Coup') - self.assertEquals(results[1].year, 2015) - self.assertEquals(results[1].month, 6) - self.assertEquals(results[1].day, 9) - self.assertEquals(results[1].score, 2700) - self.assertEquals(results[1].user, "jmckenzie") - - self.assertEquals(results[2].game, 'Coup') - self.assertEquals(results[2].year, 2015) - self.assertEquals(results[2].month, 6) - self.assertEquals(results[2].day, 1) - self.assertEquals(results[2].score, 2500) - self.assertEquals(results[2].user, "iamaleksey") + self.assertEqual(results[0].game, 'Coup') + self.assertEqual(results[0].year, 2015) + self.assertEqual(results[0].month, 6) + self.assertEqual(results[0].day, 20) + self.assertEqual(results[0].score, 3500) + self.assertEqual(results[0].user, "jbellis") + + self.assertEqual(results[1].game, 'Coup') + self.assertEqual(results[1].year, 2015) + self.assertEqual(results[1].month, 6) + self.assertEqual(results[1].day, 9) + self.assertEqual(results[1].score, 2700) + self.assertEqual(results[1].user, "jmckenzie") + + self.assertEqual(results[2].game, 'Coup') + self.assertEqual(results[2].year, 2015) + self.assertEqual(results[2].month, 6) + self.assertEqual(results[2].day, 1) + self.assertEqual(results[2].score, 2500) + self.assertEqual(results[2].user, "iamaleksey") # Test filtered user high scores query_statement = SimpleStatement("SELECT * FROM {0}.filtereduserhigh WHERE game='Chess'".format(self.keyspace_name), consistency_level=ConsistencyLevel.QUORUM) results = self.session.execute(query_statement) - self.assertEquals(results[0].game, 'Chess') - self.assertEquals(results[0].year, 2015) - self.assertEquals(results[0].month, 6) - self.assertEquals(results[0].day, 21) - self.assertEquals(results[0].score, 3500) - self.assertEquals(results[0].user, "jbellis") + 
self.assertEqual(results[0].game, 'Chess') + self.assertEqual(results[0].year, 2015) + self.assertEqual(results[0].month, 6) + self.assertEqual(results[0].day, 21) + self.assertEqual(results[0].score, 3500) + self.assertEqual(results[0].user, "jbellis") - self.assertEquals(results[1].game, 'Chess') - self.assertEquals(results[1].year, 2015) - self.assertEquals(results[1].month, 1) - self.assertEquals(results[1].day, 25) - self.assertEquals(results[1].score, 3200) - self.assertEquals(results[1].user, "pcmanus") + self.assertEqual(results[1].game, 'Chess') + self.assertEqual(results[1].year, 2015) + self.assertEqual(results[1].month, 1) + self.assertEqual(results[1].day, 25) + self.assertEqual(results[1].score, 3200) + self.assertEqual(results[1].user, "pcmanus") class UnicodeQueryTest(BasicSharedKeyspaceUnitTestCase): @@ -902,3 +1348,307 @@ def test_unicode(self): self.session.execute(bound) +class BaseKeyspaceTests(): + @classmethod + def setUpClass(cls): + cls.cluster = TestCluster() + cls.session = cls.cluster.connect(wait_for_all_pools=True) + cls.ks_name = cls.__name__.lower() + + cls.alternative_ks = "alternative_keyspace" + cls.table_name = "table_query_keyspace_tests" + + ddl = """CREATE KEYSPACE {0} WITH replication = + {{'class': 'SimpleStrategy', + 'replication_factor': '{1}'}}""".format(cls.ks_name, 1) + cls.session.execute(ddl) + + ddl = """CREATE KEYSPACE {0} WITH replication = + {{'class': 'SimpleStrategy', + 'replication_factor': '{1}'}}""".format(cls.alternative_ks, 1) + cls.session.execute(ddl) + + ddl = ''' + CREATE TABLE {0}.{1} ( + k int PRIMARY KEY, + v int )'''.format(cls.ks_name, cls.table_name) + cls.session.execute(ddl) + ddl = ''' + CREATE TABLE {0}.{1} ( + k int PRIMARY KEY, + v int )'''.format(cls.alternative_ks, cls.table_name) + cls.session.execute(ddl) + + cls.session.execute("INSERT INTO {}.{} (k, v) VALUES (1, 1)".format(cls.ks_name, cls.table_name)) + cls.session.execute("INSERT INTO {}.{} (k, v) VALUES (2, 2)".format(cls.alternative_ks, cls.table_name)) + + @classmethod + def tearDownClass(cls): + ddl = "DROP KEYSPACE {}".format(cls.alternative_ks) + cls.session.execute(ddl) + ddl = "DROP KEYSPACE {}".format(cls.ks_name) + cls.session.execute(ddl) + cls.cluster.shutdown() + + +class QueryKeyspaceTests(BaseKeyspaceTests): + + def test_setting_keyspace(self): + """ + Test the basic functionality of PYTHON-678, the keyspace can be set + independently of the query and read the results + + @since 3.12 + @jira_ticket PYTHON-678 + @expected_result the query is executed and the results retrieved + + @test_category query + """ + self._check_set_keyspace_in_statement(self.session) + + @requirecassandra + @greaterthanorequalcass40 + def test_setting_keyspace_and_session(self): + """ + Test we can still send the keyspace independently even the session + connects to a keyspace when it's created + + @since 3.12 + @jira_ticket PYTHON-678 + @expected_result the query is executed and the results retrieved + + @test_category query + """ + cluster = TestCluster(protocol_version=ProtocolVersion.V5, allow_beta_protocol_version=True) + session = cluster.connect(self.alternative_ks) + self.addCleanup(cluster.shutdown) + + self._check_set_keyspace_in_statement(session) + + def test_setting_keyspace_and_session_after_created(self): + """ + Test we can still send the keyspace independently even the session + connects to a different keyspace after being created + + @since 3.12 + @jira_ticket PYTHON-678 + @expected_result the query is executed and the results retrieved + + 
@test_category query + """ + cluster = TestCluster() + session = cluster.connect() + self.addCleanup(cluster.shutdown) + + session.set_keyspace(self.alternative_ks) + self._check_set_keyspace_in_statement(session) + + def test_setting_keyspace_and_same_session(self): + """ + Test we can still send the keyspace independently even if the session + is connected to the sent keyspace + + @since 3.12 + @jira_ticket PYTHON-678 + @expected_result the query is executed and the results retrieved + + @test_category query + """ + cluster = TestCluster() + session = cluster.connect(self.ks_name) + self.addCleanup(cluster.shutdown) + + self._check_set_keyspace_in_statement(session) + + +@greaterthanorequalcass40 +class SimpleWithKeyspaceTests(QueryKeyspaceTests, unittest.TestCase): + @unittest.skip + def test_lower_protocol(self): + cluster = TestCluster(protocol_version=ProtocolVersion.V4) + session = cluster.connect(self.ks_name) + self.addCleanup(cluster.shutdown) + + simple_stmt = SimpleStatement("SELECT * from {}".format(self.table_name), keyspace=self.ks_name) + # This raises cassandra.cluster.NoHostAvailable: ('Unable to complete the operation against + # any hosts', {: UnsupportedOperation('Keyspaces may only be + # set on queries with protocol version 5 or higher. Consider setting Cluster.protocol_version to 5.',), + # : ConnectionException('Host has been marked down or removed',), + # : ConnectionException('Host has been marked down or removed',)}) + with self.assertRaises(NoHostAvailable): + session.execute(simple_stmt) + + def _check_set_keyspace_in_statement(self, session): + simple_stmt = SimpleStatement("SELECT * from {}".format(self.table_name), keyspace=self.ks_name) + results = session.execute(simple_stmt) + self.assertEqual(results[0], (1, 1)) + + simple_stmt = SimpleStatement("SELECT * from {}".format(self.table_name)) + simple_stmt.keyspace = self.ks_name + results = session.execute(simple_stmt) + self.assertEqual(results[0], (1, 1)) + + +@greaterthanorequalcass40 +class BatchWithKeyspaceTests(QueryKeyspaceTests, unittest.TestCase): + def _check_set_keyspace_in_statement(self, session): + batch_stmt = BatchStatement() + for i in range(10): + batch_stmt.add("INSERT INTO {} (k, v) VALUES (%s, %s)".format(self.table_name), (i, i)) + + batch_stmt.keyspace = self.ks_name + session.execute(batch_stmt) + self.confirm_results() + + def confirm_results(self): + keys = set() + values = set() + # Assuming the test data is inserted at default CL.ONE, we need ALL here to guarantee we see + # everything inserted + results = self.session.execute(SimpleStatement("SELECT * FROM {}.{}".format(self.ks_name, self.table_name), + consistency_level=ConsistencyLevel.ALL)) + for result in results: + keys.add(result.k) + values.add(result.v) + + self.assertEqual(set(range(10)), keys, msg=results) + self.assertEqual(set(range(10)), values, msg=results) + + +@greaterthanorequalcass40 +class PreparedWithKeyspaceTests(BaseKeyspaceTests, unittest.TestCase): + + def setUp(self): + self.cluster = TestCluster() + self.session = self.cluster.connect() + + def tearDown(self): + self.cluster.shutdown() + + def test_prepared_with_keyspace_explicit(self): + """ + Test the basic functionality of PYTHON-678, the keyspace can be set + independently of the query and read the results + + @since 3.12 + @jira_ticket PYTHON-678 + @expected_result the query is executed and the results retrieved + + @test_category query + """ + query = "SELECT * from {} WHERE k = ?".format(self.table_name) + prepared_statement = 
self.session.prepare(query, keyspace=self.ks_name)
+
+        results = self.session.execute(prepared_statement, (1, ))
+        self.assertEqual(results[0], (1, 1))
+
+        prepared_statement_alternative = self.session.prepare(query, keyspace=self.alternative_ks)
+
+        self.assertNotEqual(prepared_statement.query_id, prepared_statement_alternative.query_id)
+
+        results = self.session.execute(prepared_statement_alternative, (2,))
+        self.assertEqual(results[0], (2, 2))
+
+    def test_reprepare_after_host_is_down(self):
+        """
+        Test that Cluster._prepare_all_queries is called when a node
+        comes back up and that the queries succeed afterwards
+
+        @since 3.12
+        @jira_ticket PYTHON-678
+        @expected_result the query is executed and the results retrieved
+
+        @test_category query
+        """
+        mock_handler = MockLoggingHandler()
+        logger = logging.getLogger(cluster.__name__)
+        logger.addHandler(mock_handler)
+        get_node(1).stop(wait=True, gently=True, wait_other_notice=True)
+
+        only_first = ExecutionProfile(load_balancing_policy=WhiteListRoundRobinPolicy(["127.0.0.1"]))
+        self.cluster.add_execution_profile("only_first", only_first)
+
+        query = "SELECT v from {} WHERE k = ?".format(self.table_name)
+        prepared_statement = self.session.prepare(query, keyspace=self.ks_name)
+        prepared_statement_alternative = self.session.prepare(query, keyspace=self.alternative_ks)
+
+        get_node(1).start(wait_for_binary_proto=True, wait_other_notice=True)
+
+        # We wait for cluster._prepare_all_queries to be called
+        time.sleep(5)
+        self.assertEqual(1, mock_handler.get_message_count('debug', 'Preparing all known prepared statements'))
+        results = self.session.execute(prepared_statement, (1,), execution_profile="only_first")
+        self.assertEqual(results[0], (1, ))
+
+        results = self.session.execute(prepared_statement_alternative, (2,), execution_profile="only_first")
+        self.assertEqual(results[0], (2, ))
+
+    def test_prepared_not_found(self):
+        """
+        Test that if a query fails on a node that didn't have
+        the query prepared, it is re-prepared as expected and then
+        the query is executed
+
+        @since 3.12
+        @jira_ticket PYTHON-678
+        @expected_result the query is executed and the results retrieved
+
+        @test_category query
+        """
+        cluster = TestCluster()
+        session = cluster.connect("system")
+        self.addCleanup(cluster.shutdown)
+
+        cluster.prepare_on_all_hosts = False
+        query = "SELECT k from {} WHERE k = ?".format(self.table_name)
+        prepared_statement = session.prepare(query, keyspace=self.ks_name)
+
+        for _ in range(10):
+            results = session.execute(prepared_statement, (1, ))
+            self.assertEqual(results[0], (1,))
+
+    def test_prepared_in_query_keyspace(self):
+        """
+        Test that the keyspace can be set in the query
+
+        @since 3.12
+        @jira_ticket PYTHON-678
+        @expected_result the results are retrieved correctly
+
+        @test_category query
+        """
+        cluster = TestCluster()
+        session = cluster.connect()
+        self.addCleanup(cluster.shutdown)
+
+        query = "SELECT k from {}.{} WHERE k = ?".format(self.ks_name, self.table_name)
+        prepared_statement = session.prepare(query)
+        results = session.execute(prepared_statement, (1,))
+        self.assertEqual(results[0], (1,))
+
+        query = "SELECT k from {}.{} WHERE k = ?".format(self.alternative_ks, self.table_name)
+        prepared_statement = session.prepare(query)
+        results = session.execute(prepared_statement, (2,))
+        self.assertEqual(results[0], (2,))
+
+    def test_prepared_in_query_keyspace_and_explicit(self):
+        """
+        Test that the keyspace set explicitly is ignored if it is
+        also specified in the query
+
+        @since 3.12
+        
@jira_ticket PYTHON-678 + @expected_result the keyspace set explicitly is ignored and + the results are retrieved correctly + + @test_category query + """ + query = "SELECT k from {}.{} WHERE k = ?".format(self.ks_name, self.table_name) + prepared_statement = self.session.prepare(query, keyspace="system") + results = self.session.execute(prepared_statement, (1,)) + self.assertEqual(results[0], (1,)) + + query = "SELECT k from {}.{} WHERE k = ?".format(self.ks_name, self.table_name) + prepared_statement = self.session.prepare(query, keyspace=self.alternative_ks) + results = self.session.execute(prepared_statement, (1,)) + self.assertEqual(results[0], (1,)) diff --git a/tests/integration/standard/test_query_paging.py b/tests/integration/standard/test_query_paging.py index fdad1bc3ee..465ef8b601 100644 --- a/tests/integration/standard/test_query_paging.py +++ b/tests/integration/standard/test_query_paging.py @@ -1,32 +1,30 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
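# Editor's illustrative sketch, not part of this patch: the test_paging_state test added
# below drives the public paging API (ResultSet.paging_state / has_more_pages / current_rows
# and Session.execute(..., paging_state=...)). A minimal, hedged version of the same flow
# from application code; the contact point, keyspace and fetch size are placeholder
# assumptions, not values taken from this diff.
from cassandra.cluster import Cluster
from cassandra.query import SimpleStatement

cluster = Cluster(["127.0.0.1"])
session = cluster.connect("test3rf")

statement = SimpleStatement("SELECT * FROM test", fetch_size=3)  # small pages on purpose
rows, paging_state = [], None
while True:
    result_set = session.execute(statement, paging_state=paging_state)
    rows.extend(result_set.current_rows)      # only the rows of the current page
    paging_state = result_set.paging_state    # opaque token used to resume the query
    if not result_set.has_more_pages:
        break
cluster.shutdown()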
- -from tests.integration import use_singledc, PROTOCOL_VERSION +from tests.integration import use_singledc, PROTOCOL_VERSION, TestCluster import logging log = logging.getLogger(__name__) -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest from itertools import cycle, count -from six.moves import range from threading import Event -from cassandra.cluster import Cluster +from cassandra import ConsistencyLevel +from cassandra.cluster import EXEC_PROFILE_DEFAULT, ExecutionProfile from cassandra.concurrent import execute_concurrent, execute_concurrent_with_args from cassandra.policies import HostDistance from cassandra.query import SimpleStatement @@ -44,10 +42,12 @@ def setUp(self): "Protocol 2.0+ is required for Paging state, currently testing against %r" % (PROTOCOL_VERSION,)) - self.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + self.cluster = TestCluster( + execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(consistency_level=ConsistencyLevel.LOCAL_QUORUM)} + ) if PROTOCOL_VERSION < 3: self.cluster.set_core_connections_per_host(HostDistance.LOCAL, 1) - self.session = self.cluster.connect() + self.session = self.cluster.connect(wait_for_all_pools=True) self.session.execute("TRUNCATE test3rf.test") def tearDown(self): @@ -69,6 +69,34 @@ def test_paging(self): self.assertEqual(100, len(list(self.session.execute(prepared)))) + def test_paging_state(self): + """ + Test to validate paging state api + @since 3.7.0 + @jira_ticket PYTHON-200 + @expected_result paging state should returned should be accurate, and allow for queries to be resumed. + + @test_category queries + """ + statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]), + [(i, ) for i in range(100)]) + execute_concurrent(self.session, list(statements_and_params)) + + list_all_results = [] + self.session.default_fetch_size = 3 + + result_set = self.session.execute("SELECT * FROM test3rf.test") + while(result_set.has_more_pages): + for row in result_set.current_rows: + self.assertNotIn(row, list_all_results) + list_all_results.extend(result_set.current_rows) + page_state = result_set.paging_state + result_set = self.session.execute("SELECT * FROM test3rf.test", paging_state=page_state) + + if(len(result_set.current_rows) > 0): + list_all_results.append(result_set.current_rows) + self.assertEqual(len(list_all_results), 100) + def test_paging_verify_writes(self): statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]), [(i, ) for i in range(100)]) @@ -226,6 +254,15 @@ def test_async_paging_verify_writes(self): self.assertSequenceEqual(range(1, 101), value_array) def test_paging_callbacks(self): + """ + Test to validate callback api + @since 3.9.0 + @jira_ticket PYTHON-733 + @expected_result callbacks shouldn't be called twice per message + and the fetch_size should be handled in a transparent way to the user + + @test_category queries + """ statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]), [(i, ) for i in range(100)]) execute_concurrent(self.session, list(statements_and_params)) @@ -234,12 +271,14 @@ def test_paging_callbacks(self): for fetch_size in (2, 3, 7, 10, 99, 100, 101, 10000): self.session.default_fetch_size = fetch_size - future = self.session.execute_async("SELECT * FROM test3rf.test") + future = self.session.execute_async("SELECT * FROM test3rf.test", timeout=20) event = Event() counter = count() + number_of_calls = count() - def handle_page(rows, future, counter): + 
def handle_page(rows, future, counter, number_of_calls): + next(number_of_calls) for row in rows: next(counter) @@ -252,26 +291,34 @@ def handle_error(err): event.set() self.fail(err) - future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error) + future.add_callbacks(callback=handle_page, callback_args=(future, counter, number_of_calls), + errback=handle_error) event.wait() + self.assertEqual(next(number_of_calls), 100 // fetch_size + 1) self.assertEqual(next(counter), 100) # simple statement - future = self.session.execute_async(SimpleStatement("SELECT * FROM test3rf.test")) + future = self.session.execute_async(SimpleStatement("SELECT * FROM test3rf.test"), timeout=20) event.clear() counter = count() + number_of_calls = count() - future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error) + future.add_callbacks(callback=handle_page, callback_args=(future, counter, number_of_calls), + errback=handle_error) event.wait() + self.assertEqual(next(number_of_calls), 100 // fetch_size + 1) self.assertEqual(next(counter), 100) # prepared statement - future = self.session.execute_async(prepared) + future = self.session.execute_async(prepared, timeout=20) event.clear() counter = count() + number_of_calls = count() - future.add_callbacks(callback=handle_page, callback_args=(future, counter), errback=handle_error) + future.add_callbacks(callback=handle_page, callback_args=(future, counter, number_of_calls), + errback=handle_error) event.wait() + self.assertEqual(next(number_of_calls), 100 // fetch_size + 1) self.assertEqual(next(counter), 100) def test_concurrent_with_paging(self): diff --git a/tests/integration/standard/test_routing.py b/tests/integration/standard/test_routing.py index 1ad0a2fa5d..d41e06df6b 100644 --- a/tests/integration/standard/test_routing.py +++ b/tests/integration/standard/test_routing.py @@ -1,29 +1,26 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
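# Editor's illustrative sketch, not part of this patch: what the routing tests below
# exercise. Binding a prepared statement populates routing_key (the serialized partition
# key), which token-aware routing and Metadata.get_replicas use to locate replicas. The
# table name and contact point are placeholder assumptions; test1rf comes from the tests.
from uuid import uuid1
from cassandra.cluster import Cluster

cluster = Cluster(["127.0.0.1"])
session = cluster.connect("test1rf")
session.execute("CREATE TABLE IF NOT EXISTS routing_demo (k uuid PRIMARY KEY, v int)")

insert = session.prepare("INSERT INTO routing_demo (k, v) VALUES (?, ?)")
bound = insert.bind((uuid1(), 0))

# replicas for this partition key, derived from the cluster's token map
replicas = cluster.metadata.get_replicas("test1rf", bound.routing_key)
print(bound.routing_key.hex(), [host.address for host in replicas])
cluster.shutdown()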
-try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest from uuid import uuid1 import logging log = logging.getLogger(__name__) -from cassandra.cluster import Cluster - -from tests.integration import use_singledc, PROTOCOL_VERSION +from tests.integration import use_singledc, TestCluster def setup_module(): @@ -38,7 +35,7 @@ def cfname(self): @classmethod def setup_class(cls): - cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cls.cluster = TestCluster() cls.session = cls.cluster.connect('test1rf') @classmethod @@ -74,7 +71,7 @@ def create_prepare(self, key_types): select = s.prepare("SELECT token(%s) FROM %s WHERE %s" % (primary_key, table_name, where_clause)) - return (insert, select) + return insert, select def test_singular_key(self): # string diff --git a/tests/integration/standard/test_row_factories.py b/tests/integration/standard/test_row_factories.py index a18646489b..97f16ea106 100644 --- a/tests/integration/standard/test_row_factories.py +++ b/tests/integration/standard/test_row_factories.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,14 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from tests.integration import get_server_versions, use_singledc, PROTOCOL_VERSION, BasicSharedKeyspaceUnitTestCaseWFunctionTable +from tests.integration import get_server_versions, use_singledc, \ + BasicSharedKeyspaceUnitTestCaseWFunctionTable, BasicSharedKeyspaceUnitTestCase, execute_until_pass, TestCluster -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest -from cassandra.cluster import Cluster, ResultSet +from cassandra.cluster import ResultSet, ExecutionProfile, EXEC_PROFILE_DEFAULT from cassandra.query import tuple_factory, named_tuple_factory, dict_factory, ordered_dict_factory from cassandra.util import OrderedDict @@ -28,42 +28,72 @@ def setup_module(): use_singledc() +class NameTupleFactory(BasicSharedKeyspaceUnitTestCase): + + def setUp(self): + self.common_setup(1, execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=named_tuple_factory)}) + ddl = ''' + CREATE TABLE {0}.{1} ( + k int PRIMARY KEY, + v1 text, + v2 text, + v3 text)'''.format(self.ks_name, self.function_table_name) + self.session.execute(ddl) + execute_until_pass(self.session, ddl) + + def test_sanitizing(self): + """ + Test to ensure that same named results are surfaced in the NamedTupleFactory + + Creates a table with a few different text fields. Inserts a few values in that table. 
+ It then fetches the values and confirms that despite all be being selected as the same name + they are propagated in the result set differently. + + @since 3.3 + @jira_ticket PYTHON-467 + @expected_result duplicate named results have unique row names. + + @test_category queries + """ + + for x in range(5): + insert1 = ''' + INSERT INTO {0}.{1} + ( k , v1, v2, v3 ) + VALUES + ( 1 , 'v1{2}', 'v2{2}','v3{2}' ) + '''.format(self.keyspace_name, self.function_table_name, str(x)) + self.session.execute(insert1) + + query = "SELECT v1 AS duplicate, v2 AS duplicate, v3 AS duplicate from {0}.{1}".format(self.ks_name, self.function_table_name) + rs = self.session.execute(query) + row = rs[0] + self.assertTrue(hasattr(row, 'duplicate')) + self.assertTrue(hasattr(row, 'duplicate_')) + self.assertTrue(hasattr(row, 'duplicate__')) + + class RowFactoryTests(BasicSharedKeyspaceUnitTestCaseWFunctionTable): """ Test different row_factories and access code """ - def setUp(self): - super(RowFactoryTests, self).setUp() - self.insert1 = ''' - INSERT INTO {0}.{1} - ( k , v ) - VALUES - ( 1 , 1 ) - '''.format(self.keyspace_name, self.function_table_name) - - self.insert2 = ''' - INSERT INTO {0}.{1} - ( k , v ) - VALUES - ( 2 , 2 ) - '''.format(self.keyspace_name, self.function_table_name) - - self.select = ''' - SELECT * FROM {0}.{1} - '''.format(self.keyspace_name, self.function_table_name) - - def tearDown(self): - self.drop_function_table() + @classmethod + def setUpClass(cls): + cls.common_setup(rf=1, create_class_table=True) + q = "INSERT INTO {0}.{1} (k, v) VALUES (%s, %s)".format(cls.ks_name, cls.ks_name) + cls.session.execute(q, (1, 1)) + cls.session.execute(q, (2, 2)) + cls.select = "SELECT * FROM {0}.{1}".format(cls.ks_name, cls.ks_name) + + def _results_from_row_factory(self, row_factory): + cluster = TestCluster( + execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=row_factory)} + ) + with cluster: + return cluster.connect().execute(self.select) def test_tuple_factory(self): - session = self.session - session.row_factory = tuple_factory - - session.execute(self.insert1) - session.execute(self.insert2) - - result = session.execute(self.select) - + result = self._results_from_row_factory(tuple_factory) self.assertIsInstance(result, ResultSet) self.assertIsInstance(result[0], tuple) @@ -76,14 +106,7 @@ def test_tuple_factory(self): self.assertEqual(result[1][0], 2) def test_named_tuple_factory(self): - session = self.session - session.row_factory = named_tuple_factory - - session.execute(self.insert1) - session.execute(self.insert2) - - result = session.execute(self.select) - + result = self._results_from_row_factory(named_tuple_factory) self.assertIsInstance(result, ResultSet) result = list(result) @@ -95,17 +118,10 @@ def test_named_tuple_factory(self): self.assertEqual(result[1].k, result[1].v) self.assertEqual(result[1].k, 2) - def test_dict_factory(self): - session = self.session - session.row_factory = dict_factory - - session.execute(self.insert1) - session.execute(self.insert2) - - result = session.execute(self.select) - + def _test_dict_factory(self, row_factory, row_type): + result = self._results_from_row_factory(row_factory) self.assertIsInstance(result, ResultSet) - self.assertIsInstance(result[0], dict) + self.assertIsInstance(result[0], row_type) for row in result: self.assertEqual(row['k'], row['v']) @@ -115,25 +131,42 @@ def test_dict_factory(self): self.assertEqual(result[1]['k'], result[1]['v']) self.assertEqual(result[1]['k'], 2) + def test_dict_factory(self): + 
self._test_dict_factory(dict_factory, dict) + def test_ordered_dict_factory(self): - session = self.session - session.row_factory = ordered_dict_factory + self._test_dict_factory(ordered_dict_factory, OrderedDict) - session.execute(self.insert1) - session.execute(self.insert2) + def test_generator_row_factory(self): + """ + Test that ResultSet.one() works with a row_factory that contains a generator. - result = session.execute(self.select) + @since 3.16 + @jira_ticket PYTHON-1026 + @expected_result one() returns the first row - self.assertIsInstance(result, ResultSet) - self.assertIsInstance(result[0], OrderedDict) + @test_category queries + """ + def generator_row_factory(column_names, rows): + return _gen_row_factory(rows) - for row in result: - self.assertEqual(row['k'], row['v']) + def _gen_row_factory(rows): + for r in rows: + yield r - self.assertEqual(result[0]['k'], result[0]['v']) - self.assertEqual(result[0]['k'], 1) - self.assertEqual(result[1]['k'], result[1]['v']) - self.assertEqual(result[1]['k'], 2) + session = self.session + session.row_factory = generator_row_factory + + session.execute(''' + INSERT INTO {0}.{1} + ( k , v ) + VALUES + ( 1 , 1 ) + '''.format(self.keyspace_name, self.function_table_name)) + result = session.execute(self.select) + self.assertIsInstance(result, ResultSet) + first_row = result.one() + self.assertEqual(first_row[0], first_row[1]) class NamedTupleFactoryAndNumericColNamesTests(unittest.TestCase): @@ -142,12 +175,11 @@ class NamedTupleFactoryAndNumericColNamesTests(unittest.TestCase): """ @classmethod def setup_class(cls): - cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + cls.cluster = TestCluster() cls.session = cls.cluster.connect() cls._cass_version, cls._cql_version = get_server_versions() ddl = ''' - CREATE TABLE test1rf.table_num_col ( key blob PRIMARY KEY, "626972746864617465" blob ) - WITH COMPACT STORAGE''' + CREATE TABLE test1rf.table_num_col ( key blob PRIMARY KEY, "626972746864617465" blob )''' cls.session.execute(ddl) @classmethod @@ -180,8 +212,10 @@ def test_can_select_with_dict_factory(self): """ can SELECT numeric column using dict_factory """ - self.session.row_factory = dict_factory - try: - self.session.execute('SELECT * FROM test1rf.table_num_col') - except ValueError as e: - self.fail("Unexpected ValueError exception: %s" % e.message) + with TestCluster( + execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=dict_factory)} + ) as cluster: + try: + cluster.connect().execute('SELECT * FROM test1rf.table_num_col') + except ValueError as e: + self.fail("Unexpected ValueError exception: %s" % e.message) diff --git a/tests/integration/standard/test_single_interface.py b/tests/integration/standard/test_single_interface.py new file mode 100644 index 0000000000..6ff331060a --- /dev/null +++ b/tests/integration/standard/test_single_interface.py @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from cassandra import ConsistencyLevel +from cassandra.query import SimpleStatement + +from packaging.version import Version +from tests.integration import use_singledc, PROTOCOL_VERSION, \ + remove_cluster, greaterthanorequalcass40, notdse, \ + CASSANDRA_VERSION, DSE_VERSION, TestCluster, DEFAULT_SINGLE_INTERFACE_PORT + + +def setup_module(): + if not DSE_VERSION and CASSANDRA_VERSION >= Version('4-a'): + remove_cluster() + use_singledc(use_single_interface=True) + +def teardown_module(): + remove_cluster() + + +@notdse +@greaterthanorequalcass40 +class SingleInterfaceTest(unittest.TestCase): + + def setUp(self): + self.cluster = TestCluster(port=DEFAULT_SINGLE_INTERFACE_PORT) + self.session = self.cluster.connect() + + def tearDown(self): + if self.cluster is not None: + self.cluster.shutdown() + + def test_single_interface(self): + """ + Test that we can connect to a multiple hosts bound to a single interface. + """ + hosts = self.cluster.metadata._hosts + broadcast_rpc_ports = [] + broadcast_ports = [] + self.assertEqual(len(hosts), 3) + for endpoint, host in hosts.items(): + + self.assertEqual(endpoint.address, host.broadcast_rpc_address) + self.assertEqual(endpoint.port, host.broadcast_rpc_port) + + if host.broadcast_rpc_port in broadcast_rpc_ports: + self.fail("Duplicate broadcast_rpc_port") + broadcast_rpc_ports.append(host.broadcast_rpc_port) + if host.broadcast_port in broadcast_ports: + self.fail("Duplicate broadcast_port") + broadcast_ports.append(host.broadcast_port) + + for _ in range(1, 100): + self.session.execute(SimpleStatement("select * from system_distributed.view_build_status", + consistency_level=ConsistencyLevel.ALL)) + + for pool in self.session.get_pools(): + self.assertEqual(1, pool.get_state()['open_count']) diff --git a/tests/integration/standard/test_types.py b/tests/integration/standard/test_types.py index 736e7957e2..3a6de0d4b7 100644 --- a/tests/integration/standard/test_types.py +++ b/tests/integration/standard/test_types.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,28 +14,36 @@ # See the License for the specific language governing permissions and # limitations under the License. 
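# Editor's illustrative sketch, not part of this patch: the sort of primitive-type round
# trip these tests automate over PRIMITIVE_DATATYPES. The table name and sample values
# are placeholder assumptions; the test3rf keyspace is the one the test harness creates.
from decimal import Decimal
from uuid import uuid4
from cassandra.cluster import Cluster

cluster = Cluster(["127.0.0.1"])
session = cluster.connect("test3rf")
session.execute("CREATE TABLE IF NOT EXISTS type_demo (k uuid PRIMARY KEY, d decimal, t text, b blob)")

insert = session.prepare("INSERT INTO type_demo (k, d, t, b) VALUES (?, ?, ?, ?)")
key = uuid4()
session.execute(insert, (key, Decimal("3.14"), "text-value", b"\x00\x01"))

# read the row back and confirm each value survives the round trip unchanged
row = session.execute("SELECT d, t, b FROM type_demo WHERE k = %s", (key,)).one()
assert row.d == Decimal("3.14") and row.t == "text-value" and bytes(row.b) == b"\x00\x01"
cluster.shutdown()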
-try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest -from datetime import datetime +import ipaddress import math -import six +import random +import string +import socket +import uuid + +from datetime import datetime, date, time, timedelta +from decimal import Decimal +from functools import partial + +from packaging.version import Version import cassandra from cassandra import InvalidRequest -from cassandra.cluster import Cluster +from cassandra import util +from cassandra.cluster import ExecutionProfile, EXEC_PROFILE_DEFAULT from cassandra.concurrent import execute_concurrent_with_args from cassandra.cqltypes import Int32Type, EMPTY from cassandra.query import dict_factory, ordered_dict_factory -from cassandra.util import sortedset +from cassandra.util import sortedset, Duration from tests.unit.cython.utils import cythontest -from tests.integration import use_singledc, PROTOCOL_VERSION, execute_until_pass, notprotocolv1, \ - BasicSharedKeyspaceUnitTestCase, greaterthancass20, lessthancass30 -from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES, COLLECTION_TYPES, \ - get_sample, get_collection_sample +from tests.integration import use_singledc, execute_until_pass, notprotocolv1, \ + BasicSharedKeyspaceUnitTestCase, greaterthancass21, lessthancass30, greaterthanorequaldse51, \ + DSE_VERSION, greaterthanorequalcass3_10, requiredse, TestCluster, greaterthanorequalcass50 +from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES, COLLECTION_TYPES, PRIMITIVE_DATATYPES_KEYS, \ + get_sample, get_all_samples, get_collection_sample def setup_module(): @@ -60,25 +70,7 @@ def test_can_insert_blob_type_as_string(self): params = ['key1', b'blobbyblob'] query = "INSERT INTO blobstring (a, b) VALUES (%s, %s)" - # In python2, with Cassandra > 2.0, we don't treat the 'byte str' type as a blob, so we'll encode it - # as a string literal and have the following failure. - if six.PY2 and self.cql_version >= (3, 1, 0): - # Blob values can't be specified using string notation in CQL 3.1.0 and - # above which is used by default in Cassandra 2.0. - if self.cass_version >= (2, 1, 0): - msg = r'.*Invalid STRING constant \(.*?\) for "b" of type blob.*' - else: - msg = r'.*Invalid STRING constant \(.*?\) for b of type blob.*' - self.assertRaisesRegexp(InvalidRequest, msg, s.execute, query, params) - return - - # In python2, with Cassandra < 2.0, we can manually encode the 'byte str' type as hex for insertion in a blob. - if six.PY2: - cass_params = [params[0], params[1].encode('hex')] - s.execute(query, cass_params) - # In python 3, the 'bytes' type is treated as a blob, so we can correctly encode it with hex notation. 
- else: - s.execute(query, params) + s.execute(query, params) results = s.execute("SELECT * FROM blobstring")[0] for expected, actual in zip(params, results): @@ -133,7 +125,7 @@ def test_can_insert_primitive_datatypes(self): """ Test insertion of all datatype primitives """ - c = Cluster(protocol_version=PROTOCOL_VERSION) + c = TestCluster() s = c.connect(self.keyspace_name) # create table @@ -161,8 +153,29 @@ def test_can_insert_primitive_datatypes(self): for expected, actual in zip(params, results): self.assertEqual(actual, expected) + # try the same thing sending one insert at the time + s.execute("TRUNCATE alltypes;") + for i, datatype in enumerate(PRIMITIVE_DATATYPES): + single_col_name = chr(start_index + i) + single_col_names = ["zz", single_col_name] + placeholders = ','.join(["%s"] * len(single_col_names)) + single_columns_string = ', '.join(single_col_names) + for j, data_sample in enumerate(get_all_samples(datatype)): + key = i + 1000 * j + single_params = (key, data_sample) + s.execute("INSERT INTO alltypes ({0}) VALUES ({1})".format(single_columns_string, placeholders), + single_params) + # verify data + result = s.execute("SELECT {0} FROM alltypes WHERE zz=%s".format(single_columns_string), (key,))[0][1] + compare_value = data_sample + + if isinstance(data_sample, ipaddress.IPv4Address) or isinstance(data_sample, ipaddress.IPv6Address): + compare_value = str(data_sample) + self.assertEqual(result, compare_value) + # try the same thing with a prepared statement placeholders = ','.join(["?"] * len(col_names)) + s.execute("TRUNCATE alltypes;") insert = s.prepare("INSERT INTO alltypes ({0}) VALUES ({1})".format(columns_string, placeholders)) s.execute(insert.bind(params)) @@ -178,9 +191,9 @@ def test_can_insert_primitive_datatypes(self): self.assertEqual(actual, expected) # verify data with with prepared statement, use dictionary with no explicit columns - s.row_factory = ordered_dict_factory select = s.prepare("SELECT * FROM alltypes") - results = s.execute(select)[0] + results = s.execute(select, + execution_profile=s.execution_profile_clone_update(EXEC_PROFILE_DEFAULT, row_factory=ordered_dict_factory))[0] for expected, actual in zip(params, results.values()): self.assertEqual(actual, expected) @@ -192,7 +205,7 @@ def test_can_insert_collection_datatypes(self): Test insertion of all collection types """ - c = Cluster(protocol_version=PROTOCOL_VERSION) + c = TestCluster() s = c.connect(self.keyspace_name) # use tuple encoding, to convert native python tuple into raw CQL s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple @@ -202,7 +215,7 @@ def test_can_insert_collection_datatypes(self): col_names = ["zz"] start_index = ord('a') for i, collection_type in enumerate(COLLECTION_TYPES): - for j, datatype in enumerate(PRIMITIVE_DATATYPES): + for j, datatype in enumerate(PRIMITIVE_DATATYPES_KEYS): if collection_type == "map": type_string = "{0}_{1} {2}<{3}, {3}>".format(chr(start_index + i), chr(start_index + j), collection_type, datatype) @@ -221,7 +234,7 @@ def test_can_insert_collection_datatypes(self): # create the input for simple statement params = [0] for collection_type in COLLECTION_TYPES: - for datatype in PRIMITIVE_DATATYPES: + for datatype in PRIMITIVE_DATATYPES_KEYS: params.append((get_collection_sample(collection_type, datatype))) # insert into table as a simple statement @@ -236,7 +249,7 @@ def test_can_insert_collection_datatypes(self): # create the input for prepared statement params = [0] for collection_type in COLLECTION_TYPES: - for datatype in 
PRIMITIVE_DATATYPES: + for datatype in PRIMITIVE_DATATYPES_KEYS: params.append((get_collection_sample(collection_type, datatype))) # try the same thing with a prepared statement @@ -256,9 +269,10 @@ def test_can_insert_collection_datatypes(self): self.assertEqual(actual, expected) # verify data with with prepared statement, use dictionary with no explicit columns - s.row_factory = ordered_dict_factory select = s.prepare("SELECT * FROM allcoltypes") - results = s.execute(select)[0] + results = s.execute(select, + execution_profile=s.execution_profile_clone_update(EXEC_PROFILE_DEFAULT, + row_factory=ordered_dict_factory))[0] for expected, actual in zip(params, results.values()): self.assertEqual(actual, expected) @@ -423,7 +437,7 @@ def test_can_insert_tuples(self): if self.cass_version < (2, 1, 0): raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") - c = Cluster(protocol_version=PROTOCOL_VERSION) + c = TestCluster() s = c.connect(self.keyspace_name) # use this encoder in order to insert tuples @@ -475,12 +489,12 @@ def test_can_insert_tuples_with_varying_lengths(self): if self.cass_version < (2, 1, 0): raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") - c = Cluster(protocol_version=PROTOCOL_VERSION) + c = TestCluster( + execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=dict_factory)} + ) s = c.connect(self.keyspace_name) - # set the row_factory to dict_factory for programmatic access # set the encoder for tuples for the ability to write tuples - s.row_factory = dict_factory s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple # programmatically create the table with tuples of said sizes @@ -514,7 +528,7 @@ def test_can_insert_tuples_all_primitive_datatypes(self): if self.cass_version < (2, 1, 0): raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") - c = Cluster(protocol_version=PROTOCOL_VERSION) + c = TestCluster() s = c.connect(self.keyspace_name) s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple @@ -542,26 +556,26 @@ def test_can_insert_tuples_all_collection_datatypes(self): if self.cass_version < (2, 1, 0): raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1") - c = Cluster(protocol_version=PROTOCOL_VERSION) + c = TestCluster( + execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=dict_factory)} + ) s = c.connect(self.keyspace_name) - # set the row_factory to dict_factory for programmatic access # set the encoder for tuples for the ability to write tuples - s.row_factory = dict_factory s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple values = [] # create list values - for datatype in PRIMITIVE_DATATYPES: + for datatype in PRIMITIVE_DATATYPES_KEYS: values.append('v_{0} frozen>>'.format(len(values), datatype)) # create set values - for datatype in PRIMITIVE_DATATYPES: + for datatype in PRIMITIVE_DATATYPES_KEYS: values.append('v_{0} frozen>>'.format(len(values), datatype)) # create map values - for datatype in PRIMITIVE_DATATYPES: + for datatype in PRIMITIVE_DATATYPES_KEYS: datatype_1 = datatype_2 = datatype if datatype == 'blob': # unhashable type: 'bytearray' @@ -581,7 +595,7 @@ def test_can_insert_tuples_all_collection_datatypes(self): i = 0 # test tuple> - for datatype in PRIMITIVE_DATATYPES: + for datatype in PRIMITIVE_DATATYPES_KEYS: created_tuple = tuple([[get_sample(datatype)]]) s.execute("INSERT INTO tuple_non_primative (k, v_%s) VALUES (0, %s)", (i, created_tuple)) @@ -590,7 +604,7 @@ def test_can_insert_tuples_all_collection_datatypes(self): i 
+= 1
 
         # test tuple<set<datatype>>
-        for datatype in PRIMITIVE_DATATYPES:
+        for datatype in PRIMITIVE_DATATYPES_KEYS:
             created_tuple = tuple([sortedset([get_sample(datatype)])])
 
             s.execute("INSERT INTO tuple_non_primative (k, v_%s) VALUES (0, %s)", (i, created_tuple))
@@ -599,7 +613,7 @@
             i += 1
 
         # test tuple<map<datatype, datatype>>
-        for datatype in PRIMITIVE_DATATYPES:
+        for datatype in PRIMITIVE_DATATYPES_KEYS:
             if datatype == 'blob':
                 # unhashable type: 'bytearray'
                 created_tuple = tuple([{get_sample('ascii'): get_sample(datatype)}])
@@ -641,12 +655,12 @@ def test_can_insert_nested_tuples(self):
         if self.cass_version < (2, 1, 0):
             raise unittest.SkipTest("The tuple type was introduced in Cassandra 2.1")
 
-        c = Cluster(protocol_version=PROTOCOL_VERSION)
+        c = TestCluster(
+            execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=dict_factory)}
+        )
         s = c.connect(self.keyspace_name)
 
-        # set the row_factory to dict_factory for programmatic access
         # set the encoder for tuples for the ability to write tuples
-        s.row_factory = dict_factory
         s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple
 
         # create a table with multiple sizes of nested tuples
@@ -793,10 +807,425 @@ def test_cython_decimal(self):
         finally:
             self.session.execute("DROP TABLE {0}".format(self.function_table_name))
 
+    @greaterthanorequalcass3_10
+    def test_smoke_duration_values(self):
+        """
+        Test writing several Duration values to the database and verifying
+        they can be read back correctly. Then verify that an exception is raised
+        if the value is too big
+
+        @since 3.10
+        @jira_ticket PYTHON-747
+        @expected_result the read value in C* matches the written one
+
+        @test_category data_types serialization
+        """
+        self.session.execute("""
+            CREATE TABLE duration_smoke (k int primary key, v duration)
+            """)
+        self.addCleanup(self.session.execute, "DROP TABLE duration_smoke")
+
+        prepared = self.session.prepare("""
+            INSERT INTO duration_smoke (k, v)
+            VALUES (?, ?) 
+ """) + + nanosecond_smoke_values = [0, -1, 1, 100, 1000, 1000000, 1000000000, + 10000000000000,-9223372036854775807, 9223372036854775807, + int("7FFFFFFFFFFFFFFF", 16), int("-7FFFFFFFFFFFFFFF", 16)] + month_day_smoke_values = [0, -1, 1, 100, 1000, 1000000, 1000000000, + int("7FFFFFFF", 16), int("-7FFFFFFF", 16)] + + for nanosecond_value in nanosecond_smoke_values: + for month_day_value in month_day_smoke_values: + + # Must have the same sign + if (month_day_value <= 0) != (nanosecond_value <= 0): + continue + + self.session.execute(prepared, (1, Duration(month_day_value, month_day_value, nanosecond_value))) + results = self.session.execute("SELECT * FROM duration_smoke") + + v = results[0][1] + self.assertEqual(Duration(month_day_value, month_day_value, nanosecond_value), v, + "Error encoding value {0},{0},{1}".format(month_day_value, nanosecond_value)) + + self.assertRaises(ValueError, self.session.execute, prepared, + (1, Duration(0, 0, int("8FFFFFFFFFFFFFF0", 16)))) + self.assertRaises(ValueError, self.session.execute, prepared, + (1, Duration(0, int("8FFFFFFFFFFFFFF0", 16), 0))) + self.assertRaises(ValueError, self.session.execute, prepared, + (1, Duration(int("8FFFFFFFFFFFFFF0", 16), 0, 0))) + + +@requiredse +class AbstractDateRangeTest(): + + def test_single_value_daterange_round_trip(self): + self._daterange_round_trip( + util.DateRange( + value=util.DateRangeBound( + datetime(2014, 10, 1, 0), + util.DateRangePrecision.YEAR + ) + ), + util.DateRange( + value=util.DateRangeBound( + datetime(2014, 1, 1, 0), + util.DateRangePrecision.YEAR + ) + ) + ) + + def test_open_high_daterange_round_trip(self): + self._daterange_round_trip( + util.DateRange( + lower_bound=util.DateRangeBound( + datetime(2013, 10, 1, 6, 20, 39), + util.DateRangePrecision.SECOND + ) + ) + ) + + def test_open_low_daterange_round_trip(self): + self._daterange_round_trip( + util.DateRange( + upper_bound=util.DateRangeBound( + datetime(2013, 10, 28), + util.DateRangePrecision.DAY + ) + ) + ) + + def test_open_both_daterange_round_trip(self): + self._daterange_round_trip( + util.DateRange( + lower_bound=util.OPEN_BOUND, + upper_bound=util.OPEN_BOUND, + ) + ) + + def test_closed_daterange_round_trip(self): + insert = util.DateRange( + lower_bound=util.DateRangeBound( + datetime(2015, 3, 1, 10, 15, 30, 1000), + util.DateRangePrecision.MILLISECOND + ), + upper_bound=util.DateRangeBound( + datetime(2016, 1, 1, 10, 15, 30, 999000), + util.DateRangePrecision.MILLISECOND + ) + ) + self._daterange_round_trip(insert) + + def test_epoch_value_round_trip(self): + insert = util.DateRange( + value=util.DateRangeBound( + datetime(1970, 1, 1), + util.DateRangePrecision.YEAR + ) + ) + self._daterange_round_trip(insert) + + def test_double_bounded_daterange_round_trip_from_string(self): + self._daterange_round_trip( + '[2015-03-01T10:15:30.010Z TO 2016-01-01T10:15:30.999Z]', + util.DateRange( + lower_bound=util.DateRangeBound( + datetime(2015, 3, 1, 10, 15, 30, 10000), + util.DateRangePrecision.MILLISECOND + ), + upper_bound=util.DateRangeBound( + datetime(2016, 1, 1, 10, 15, 30, 999000), + util.DateRangePrecision.MILLISECOND + ), + ) + ) + + def test_open_high_daterange_round_trip_from_string(self): + self._daterange_round_trip( + '[2015-03 TO *]', + util.DateRange( + lower_bound=util.DateRangeBound( + datetime(2015, 3, 1, 0, 0), + util.DateRangePrecision.MONTH + ), + upper_bound=util.DateRangeBound(None, None) + ) + ) + + def test_open_low_daterange_round_trip_from_string(self): + self._daterange_round_trip( + '[* TO 2015-03]', + 
util.DateRange( + lower_bound=util.DateRangeBound(None, None), + upper_bound=util.DateRangeBound( + datetime(2015, 3, 1, 0, 0), + 'MONTH' + ) + ) + ) + + def test_no_bounds_daterange_round_trip_from_string(self): + self._daterange_round_trip( + '[* TO *]', + util.DateRange( + lower_bound=(None, None), + upper_bound=(None, None) + ) + ) + + def test_single_no_bounds_daterange_round_trip_from_string(self): + self._daterange_round_trip( + '*', + util.DateRange( + value=(None, None) + ) + ) + + def test_single_value_daterange_round_trip_from_string(self): + self._daterange_round_trip( + '2001-01-01T12:30:30.000Z', + util.DateRange( + value=util.DateRangeBound( + datetime(2001, 1, 1, 12, 30, 30), + 'MILLISECOND' + ) + ) + ) + + def test_daterange_with_negative_bound_round_trip_from_string(self): + self._daterange_round_trip( + '[-1991-01-01T00:00:00.001 TO 1990-02-03]', + util.DateRange( + lower_bound=(-124997039999999, 'MILLISECOND'), + upper_bound=util.DateRangeBound( + datetime(1990, 2, 3, 12, 30, 30), + 'DAY' + ) + ) + ) + + def test_epoch_value_round_trip_from_string(self): + self._daterange_round_trip( + '1970', + util.DateRange( + value=util.DateRangeBound( + datetime(1970, 1, 1), + util.DateRangePrecision.YEAR + ) + ) + ) + + +@greaterthanorequaldse51 +class TestDateRangePrepared(AbstractDateRangeTest, BasicSharedKeyspaceUnitTestCase): + """ + Tests various inserts and queries using Date-ranges and prepared queries + + @since 2.0.0 + @jira_ticket PYTHON-668 + @expected_result Date ranges will be inserted and retrieved succesfully + + @test_category data_types + """ + + @classmethod + def setUpClass(cls): + super(TestDateRangePrepared, cls).setUpClass() + cls.session.set_keyspace(cls.ks_name) + if DSE_VERSION and DSE_VERSION >= Version('5.1'): + cls.session.execute("CREATE TABLE tab (dr 'DateRangeType' PRIMARY KEY)") + + + def _daterange_round_trip(self, to_insert, expected=None): + if isinstance(to_insert, util.DateRange): + prep = self.session.prepare("INSERT INTO tab (dr) VALUES (?);") + self.session.execute(prep, (to_insert,)) + prep_sel = self.session.prepare("SELECT * FROM tab WHERE dr = ? 
") + results = self.session.execute(prep_sel, (to_insert,)) + else: + prep = self.session.prepare("INSERT INTO tab (dr) VALUES ('%s');" % (to_insert,)) + self.session.execute(prep) + prep_sel = self.session.prepare("SELECT * FROM tab WHERE dr = '%s' " % (to_insert,)) + results = self.session.execute(prep_sel) + + dr = results[0].dr + # sometimes this is truncated in the assertEqual output on failure; + if isinstance(expected, str): + self.assertEqual(str(dr), expected) + else: + self.assertEqual(dr, expected or to_insert) + + # This can only be run as a prepared statement + def test_daterange_wide(self): + self._daterange_round_trip( + util.DateRange( + lower_bound=(-9223372036854775808, 'MILLISECOND'), + upper_bound=(9223372036854775807, 'MILLISECOND') + ), + '[-9223372036854775808ms TO 9223372036854775807ms]' + ) + # This can only be run as a prepared statement + def test_daterange_with_negative_bound_round_trip_to_string(self): + self._daterange_round_trip( + util.DateRange( + lower_bound=(-124997039999999, 'MILLISECOND'), + upper_bound=util.DateRangeBound( + datetime(1990, 2, 3, 12, 30, 30), + 'DAY' + ) + ), + '[-124997039999999ms TO 1990-02-03]' + ) + +@greaterthanorequaldse51 +class TestDateRangeSimple(AbstractDateRangeTest, BasicSharedKeyspaceUnitTestCase): + """ + Tests various inserts and queries using Date-ranges and simple queries + + @since 2.0.0 + @jira_ticket PYTHON-668 + @expected_result DateRanges will be inserted and retrieved successfully + @test_category data_types + """ + @classmethod + def setUpClass(cls): + super(TestDateRangeSimple, cls).setUpClass() + cls.session.set_keyspace(cls.ks_name) + if DSE_VERSION and DSE_VERSION >= Version('5.1'): + cls.session.execute("CREATE TABLE tab (dr 'DateRangeType' PRIMARY KEY)") + + + def _daterange_round_trip(self, to_insert, expected=None): + + query = "INSERT INTO tab (dr) VALUES ('{0}');".format(to_insert) + self.session.execute("INSERT INTO tab (dr) VALUES ('{0}');".format(to_insert)) + query = "SELECT * FROM tab WHERE dr = '{0}' ".format(to_insert) + results= self.session.execute("SELECT * FROM tab WHERE dr = '{0}' ".format(to_insert)) + + dr = results[0].dr + # sometimes this is truncated in the assertEqual output on failure; + if isinstance(expected, str): + self.assertEqual(str(dr), expected) + else: + self.assertEqual(dr, expected or to_insert) + + +@greaterthanorequaldse51 +class TestDateRangeCollection(BasicSharedKeyspaceUnitTestCase): + + + @classmethod + def setUpClass(cls): + super(TestDateRangeCollection, cls).setUpClass() + cls.session.set_keyspace(cls.ks_name) + + def test_date_range_collection(self): + """ + Tests DateRange type in collections + + @since 2.0.0 + @jira_ticket PYTHON-668 + @expected_result DateRanges will be inserted and retrieved successfully when part of a list or map + @test_category data_types + """ + self.session.execute("CREATE TABLE dateRangeIntegrationTest5 (k int PRIMARY KEY, l list<'DateRangeType'>, s set<'DateRangeType'>, dr2i map<'DateRangeType', int>, i2dr map)") + self.session.execute("INSERT INTO dateRangeIntegrationTest5 (k, l, s, i2dr, dr2i) VALUES (" + + "1, " + + "['[2000-01-01T10:15:30.001Z TO 2020]', '[2010-01-01T10:15:30.001Z TO 2020]', '2001-01-02'], " + + "{'[2000-01-01T10:15:30.001Z TO 2020]', '[2000-01-01T10:15:30.001Z TO 2020]', '[2010-01-01T10:15:30.001Z TO 2020]'}, " + + "{1: '[2000-01-01T10:15:30.001Z TO 2020]', 2: '[2010-01-01T10:15:30.001Z TO 2020]'}, " + + "{'[2000-01-01T10:15:30.001Z TO 2020]': 1, '[2010-01-01T10:15:30.001Z TO 2020]': 2})") + results = 
list(self.session.execute("SELECT * FROM dateRangeIntegrationTest5")) + self.assertEqual(len(results),1) + + lower_bound_1 = util.DateRangeBound(datetime(2000, 1, 1, 10, 15, 30, 1000), 'MILLISECOND') + + lower_bound_2 = util.DateRangeBound(datetime(2010, 1, 1, 10, 15, 30, 1000), 'MILLISECOND') + + upper_bound_1 = util.DateRangeBound(datetime(2020, 1, 1), 'YEAR') + + value_1 = util.DateRangeBound(datetime(2001, 1, 2), 'DAY') + + dt = util.DateRange(lower_bound=lower_bound_1, upper_bound=upper_bound_1) + dt2 = util.DateRange(lower_bound=lower_bound_2, upper_bound=upper_bound_1) + dt3 = util.DateRange(value=value_1) + + + + list_result = results[0].l + self.assertEqual(3, len(list_result)) + self.assertEqual(list_result[0],dt) + self.assertEqual(list_result[1],dt2) + self.assertEqual(list_result[2],dt3) + + set_result = results[0].s + self.assertEqual(len(set_result), 2) + self.assertIn(dt, set_result) + self.assertIn(dt2, set_result) + + d2i = results[0].dr2i + self.assertEqual(len(d2i), 2) + self.assertEqual(d2i[dt],1) + self.assertEqual(d2i[dt2],2) + + i2r = results[0].i2dr + self.assertEqual(len(i2r), 2) + self.assertEqual(i2r[1],dt) + self.assertEqual(i2r[2],dt2) + + def test_allow_date_range_in_udt_tuple(self): + """ + Tests DateRanges in tuples and udts + + @since 2.0.0 + @jira_ticket PYTHON-668 + @expected_result DateRanges will be inserted and retrieved successfully in udt's and tuples + @test_category data_types + """ + self.session.execute("CREATE TYPE IF NOT EXISTS test_udt (i int, range 'DateRangeType')") + self.session.execute("CREATE TABLE dateRangeIntegrationTest4 (k int PRIMARY KEY, u test_udt, uf frozen, t tuple<'DateRangeType', int>, tf frozen>)") + self.session.execute("INSERT INTO dateRangeIntegrationTest4 (k, u, uf, t, tf) VALUES (" + + "1, " + + "{i: 10, range: '[2000-01-01T10:15:30.003Z TO 2020-01-01T10:15:30.001Z]'}, " + + "{i: 20, range: '[2000-01-01T10:15:30.003Z TO 2020-01-01T10:15:30.001Z]'}, " + + "('[2000-01-01T10:15:30.003Z TO 2020-01-01T10:15:30.001Z]', 30), " + + "('[2000-01-01T10:15:30.003Z TO 2020-01-01T10:15:30.001Z]', 40))") + + lower_bound = util.DateRangeBound( + datetime(2000, 1, 1, 10, 15, 30, 3000), + 'MILLISECOND') + + upper_bound = util.DateRangeBound( + datetime(2020, 1, 1, 10, 15, 30, 1000), + 'MILLISECOND') + + expected_dt = util.DateRange(lower_bound=lower_bound ,upper_bound=upper_bound) + + results_list = list(self.session.execute("SELECT * FROM dateRangeIntegrationTest4")) + self.assertEqual(len(results_list), 1) + udt = results_list[0].u + self.assertEqual(udt.range, expected_dt) + self.assertEqual(udt.i, 10) + + + uf = results_list[0].uf + self.assertEqual(uf.range, expected_dt) + self.assertEqual(uf.i, 20) + + t = results_list[0].t + self.assertEqual(t[0], expected_dt) + self.assertEqual(t[1], 30) + + tf = results_list[0].tf + self.assertEqual(tf[0], expected_dt) + self.assertEqual(tf[1], 40) + class TypeTestsProtocol(BasicSharedKeyspaceUnitTestCase): - @greaterthancass20 + @greaterthancass21 @lessthancass30 def test_nested_types_with_protocol_version(self): """ @@ -832,13 +1261,13 @@ def test_nested_types_with_protocol_version(self): self.session.execute(ddl) - for pvi in range(1, 5): + for pvi in range(3, 5): self.run_inserts_at_version(pvi) - for pvr in range(1, 5): + for pvr in range(3, 5): self.read_inserts_at_level(pvr) def read_inserts_at_level(self, proto_ver): - session = Cluster(protocol_version=proto_ver).connect(self.keyspace_name) + session = TestCluster(protocol_version=proto_ver).connect(self.keyspace_name) try: results = 
session.execute('select * from t')[0] self.assertEqual("[SortedSet([1, 2]), SortedSet([3, 5])]", str(results.v)) @@ -856,7 +1285,7 @@ def read_inserts_at_level(self, proto_ver): session.cluster.shutdown() def run_inserts_at_version(self, proto_ver): - session = Cluster(protocol_version=proto_ver).connect(self.keyspace_name) + session = TestCluster(protocol_version=proto_ver).connect(self.keyspace_name) try: p = session.prepare('insert into t (k, v) values (?, ?)') session.execute(p, (0, [{1, 2}, {3, 5}])) @@ -873,5 +1302,179 @@ def run_inserts_at_version(self, proto_ver): finally: session.cluster.shutdown() +@greaterthanorequalcass50 +class TypeTestsVector(BasicSharedKeyspaceUnitTestCase): + + def _get_first_j(self, rs): + rows = rs.all() + self.assertEqual(len(rows), 1) + return rows[0].j + + def _get_row_simple(self, idx, table_name): + rs = self.session.execute("select j from {0}.{1} where i = {2}".format(self.keyspace_name, table_name, idx)) + return self._get_first_j(rs) + + def _get_row_prepared(self, idx, table_name): + cql = "select j from {0}.{1} where i = ?".format(self.keyspace_name, table_name) + ps = self.session.prepare(cql) + rs = self.session.execute(ps, [idx]) + return self._get_first_j(rs) + def _round_trip_test(self, subtype, subtype_fn, test_fn, use_positional_parameters=True): + + table_name = subtype.replace("<","A").replace(">", "B").replace(",", "C") + "isH" + + def random_subtype_vector(): + return [subtype_fn() for _ in range(3)] + + ddl = """CREATE TABLE {0}.{1} ( + i int PRIMARY KEY, + j vector<{2}, 3>)""".format(self.keyspace_name, table_name, subtype) + self.session.execute(ddl) + if use_positional_parameters: + cql = "insert into {0}.{1} (i,j) values (%s,%s)".format(self.keyspace_name, table_name) + expected1 = random_subtype_vector() + data1 = {1:random_subtype_vector(), 2:expected1, 3:random_subtype_vector()} + for k,v in data1.items(): + # Attempt a set of inserts using the driver's support for positional params + self.session.execute(cql, (k,v)) + + cql = "insert into {0}.{1} (i,j) values (?,?)".format(self.keyspace_name, table_name) + expected2 = random_subtype_vector() + ps = self.session.prepare(cql) + data2 = {4:random_subtype_vector(), 5:expected2, 6:random_subtype_vector()} + for k,v in data2.items(): + # Add some additional rows via prepared statements + self.session.execute(ps, [k,v]) + + # Use prepared queries to gather data from the rows we added via simple queries and vice versa + if use_positional_parameters: + observed1 = self._get_row_prepared(2, table_name) + for idx in range(0, 3): + test_fn(observed1[idx], expected1[idx]) + + observed2 = self._get_row_simple(5, table_name) + for idx in range(0, 3): + test_fn(observed2[idx], expected2[idx]) + + def test_round_trip_integers(self): + self._round_trip_test("int", partial(random.randint, 0, 2 ** 31), self.assertEqual) + self._round_trip_test("bigint", partial(random.randint, 0, 2 ** 63), self.assertEqual) + self._round_trip_test("smallint", partial(random.randint, 0, 2 ** 15), self.assertEqual) + self._round_trip_test("tinyint", partial(random.randint, 0, (2 ** 7) - 1), self.assertEqual) + self._round_trip_test("varint", partial(random.randint, 0, 2 ** 63), self.assertEqual) + + def test_round_trip_floating_point(self): + _almost_equal_test_fn = partial(self.assertAlmostEqual, places=5) + def _random_decimal(): + return Decimal(random.uniform(0.0, 100.0)) + + # Max value here isn't really connected to max value for floating point nums in IEEE 754... 
it's used here + # mainly as a convenient benchmark + self._round_trip_test("float", partial(random.uniform, 0.0, 100.0), _almost_equal_test_fn) + self._round_trip_test("double", partial(random.uniform, 0.0, 100.0), _almost_equal_test_fn) + self._round_trip_test("decimal", _random_decimal, _almost_equal_test_fn) + + def test_round_trip_text(self): + def _random_string(): + return ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(24)) + + self._round_trip_test("ascii", _random_string, self.assertEqual) + self._round_trip_test("text", _random_string, self.assertEqual) + + def test_round_trip_date_and_time(self): + _almost_equal_test_fn = partial(self.assertAlmostEqual, delta=timedelta(seconds=1)) + def _random_datetime(): + return datetime.today() - timedelta(hours=random.randint(0,18), days=random.randint(1,1000)) + def _random_date(): + return _random_datetime().date() + def _random_time(): + return _random_datetime().time() + + self._round_trip_test("date", _random_date, self.assertEqual) + self._round_trip_test("time", _random_time, self.assertEqual) + self._round_trip_test("timestamp", _random_datetime, _almost_equal_test_fn) + + def test_round_trip_uuid(self): + self._round_trip_test("uuid", uuid.uuid1, self.assertEqual) + self._round_trip_test("timeuuid", uuid.uuid1, self.assertEqual) + + def test_round_trip_miscellany(self): + def _random_bytes(): + return random.getrandbits(32).to_bytes(4,'big') + def _random_boolean(): + return random.choice([True, False]) + def _random_duration(): + return Duration(random.randint(0,11), random.randint(0,11), random.randint(0,10000)) + def _random_inet(): + return socket.inet_ntoa(_random_bytes()) + + self._round_trip_test("boolean", _random_boolean, self.assertEqual) + self._round_trip_test("duration", _random_duration, self.assertEqual) + self._round_trip_test("inet", _random_inet, self.assertEqual) + self._round_trip_test("blob", _random_bytes, self.assertEqual) + + def test_round_trip_collections(self): + def _random_seq(): + return [random.randint(0,100000) for _ in range(8)] + def _random_set(): + return set(_random_seq()) + def _random_map(): + return {k:v for (k,v) in zip(_random_seq(), _random_seq())} + + # Goal here is to test collections of both fixed and variable size subtypes + self._round_trip_test("list", _random_seq, self.assertEqual) + self._round_trip_test("list", _random_seq, self.assertEqual) + self._round_trip_test("set", _random_set, self.assertEqual) + self._round_trip_test("set", _random_set, self.assertEqual) + self._round_trip_test("map", _random_map, self.assertEqual) + self._round_trip_test("map", _random_map, self.assertEqual) + self._round_trip_test("map", _random_map, self.assertEqual) + self._round_trip_test("map", _random_map, self.assertEqual) + + def test_round_trip_vector_of_vectors(self): + def _random_vector(): + return [random.randint(0,100000) for _ in range(2)] + + self._round_trip_test("vector", _random_vector, self.assertEqual) + self._round_trip_test("vector", _random_vector, self.assertEqual) + + def test_round_trip_tuples(self): + def _random_tuple(): + return (random.randint(0,100000),random.randint(0,100000)) + + # Unfortunately we can't use positional parameters when inserting tuples because the driver will try to encode + # them as lists before sending them to the server... and that confuses the parsing logic. 
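        # Note: other tests in this module (e.g. test_can_insert_nested_tuples) work around
        # the list-vs-tuple encoding of simple statements by remapping the session encoder:
        #     session.encoder.mapping[tuple] = session.encoder.cql_encode_tuple
        # It is not established here whether that remapping would also satisfy the vector
        # parsing logic, so the tuple cases below stick to prepared statements
        # (use_positional_parameters=False).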
+ self._round_trip_test("tuple", _random_tuple, self.assertEqual, use_positional_parameters=False) + self._round_trip_test("tuple", _random_tuple, self.assertEqual, use_positional_parameters=False) + self._round_trip_test("tuple", _random_tuple, self.assertEqual, use_positional_parameters=False) + self._round_trip_test("tuple", _random_tuple, self.assertEqual, use_positional_parameters=False) + + def test_round_trip_udts(self): + def _udt_equal_test_fn(udt1, udt2): + self.assertEqual(udt1.a, udt2.a) + self.assertEqual(udt1.b, udt2.b) + + self.session.execute("create type {}.fixed_type (a int, b int)".format(self.keyspace_name)) + self.session.execute("create type {}.mixed_type_one (a int, b varint)".format(self.keyspace_name)) + self.session.execute("create type {}.mixed_type_two (a varint, b int)".format(self.keyspace_name)) + self.session.execute("create type {}.var_type (a varint, b varint)".format(self.keyspace_name)) + + class GeneralUDT: + def __init__(self, a, b): + self.a = a + self.b = b + + self.cluster.register_user_type(self.keyspace_name,'fixed_type', GeneralUDT) + self.cluster.register_user_type(self.keyspace_name,'mixed_type_one', GeneralUDT) + self.cluster.register_user_type(self.keyspace_name,'mixed_type_two', GeneralUDT) + self.cluster.register_user_type(self.keyspace_name,'var_type', GeneralUDT) + + def _random_udt(): + return GeneralUDT(random.randint(0,100000),random.randint(0,100000)) + + self._round_trip_test("fixed_type", _random_udt, _udt_equal_test_fn) + self._round_trip_test("mixed_type_one", _random_udt, _udt_equal_test_fn) + self._round_trip_test("mixed_type_two", _random_udt, _udt_equal_test_fn) + self._round_trip_test("var_type", _random_udt, _udt_equal_test_fn) diff --git a/tests/integration/standard/test_udts.py b/tests/integration/standard/test_udts.py index 8d4411a17e..9c3e560b76 100644 --- a/tests/integration/standard/test_udts.py +++ b/tests/integration/standard/test_udts.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,23 +14,20 @@ # See the License for the specific language governing permissions and # limitations under the License. 
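# Illustrative sketch of the registration pattern the UDT tests below exercise; the
# keyspace name "ks" is an assumption for the example, and the sketch assumes a table
# mytable (a int PRIMARY KEY, b frozen<user>) already exists:
#
#     from collections import namedtuple
#     from cassandra.cluster import Cluster
#
#     cluster = Cluster()
#     session = cluster.connect("ks")
#     session.execute("CREATE TYPE user (age int, name text)")
#     User = namedtuple('user', ('age', 'name'))
#     cluster.register_user_type("ks", "user", User)   # map the CQL UDT onto the class
#     session.execute("INSERT INTO mytable (a, b) VALUES (%s, %s)", (0, User(42, 'bob')))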
-try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest from collections import namedtuple from functools import partial -import six from cassandra import InvalidRequest -from cassandra.cluster import Cluster, UserTypeDoesNotExist +from cassandra.cluster import UserTypeDoesNotExist, ExecutionProfile, EXEC_PROFILE_DEFAULT from cassandra.query import dict_factory from cassandra.util import OrderedMap -from tests.integration import get_server_versions, use_singledc, PROTOCOL_VERSION, execute_until_pass, BasicSegregatedKeyspaceUnitTestCase, greaterthancass20 -from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES, COLLECTION_TYPES, \ - get_sample, get_collection_sample +from tests.integration import use_singledc, execute_until_pass, \ + BasicSegregatedKeyspaceUnitTestCase, greaterthancass20, lessthancass30, greaterthanorequalcass36, TestCluster +from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES, PRIMITIVE_DATATYPES_KEYS, \ + COLLECTION_TYPES, get_sample, get_collection_sample nested_collection_udt = namedtuple('nested_collection_udt', ['m', 't', 'l', 's']) nested_collection_udt_nested = namedtuple('nested_collection_udt_nested', ['m', 't', 'l', 's', 'u']) @@ -50,13 +49,36 @@ def setUp(self): super(UDTTests, self).setUp() self.session.set_keyspace(self.keyspace_name) + @greaterthanorequalcass36 + def test_non_frozen_udts(self): + """ + Test to ensure that non frozen udt's work with C* >3.6. + + @since 3.7.0 + @jira_ticket PYTHON-498 + @expected_result Non frozen UDT's are supported + + @test_category data_types, udt + """ + self.session.execute("USE {0}".format(self.keyspace_name)) + self.session.execute("CREATE TYPE user (state text, has_corn boolean)") + self.session.execute("CREATE TABLE {0} (a int PRIMARY KEY, b user)".format(self.function_table_name)) + User = namedtuple('user', ('state', 'has_corn')) + self.cluster.register_user_type(self.keyspace_name, "user", User) + self.session.execute("INSERT INTO {0} (a, b) VALUES (%s, %s)".format(self.function_table_name), (0, User("Nebraska", True))) + self.session.execute("UPDATE {0} SET b.has_corn = False where a = 0".format(self.function_table_name)) + result = self.session.execute("SELECT * FROM {0}".format(self.function_table_name)) + self.assertFalse(result[0].b.has_corn) + table_sql = self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].as_cql_query() + self.assertNotIn("", table_sql) + def test_can_insert_unprepared_registered_udts(self): """ Test the insertion of unprepared, registered UDTs """ - c = Cluster(protocol_version=PROTOCOL_VERSION) - s = c.connect(self.keyspace_name) + c = TestCluster() + s = c.connect(self.keyspace_name, wait_for_all_pools=True) s.execute("CREATE TYPE user (age int, name text)") s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen)") @@ -99,8 +121,8 @@ def test_can_register_udt_before_connecting(self): Test the registration of UDTs before session creation """ - c = Cluster(protocol_version=PROTOCOL_VERSION) - s = c.connect() + c = TestCluster() + s = c.connect(wait_for_all_pools=True) s.execute(""" CREATE KEYSPACE udt_test_register_before_connecting @@ -120,7 +142,7 @@ def test_can_register_udt_before_connecting(self): # now that types are defined, shutdown and re-create Cluster c.shutdown() - c = Cluster(protocol_version=PROTOCOL_VERSION) + c = TestCluster() User1 = namedtuple('user', ('age', 'name')) User2 = namedtuple('user', ('state', 'is_cool')) @@ -128,7 +150,7 
@@ def test_can_register_udt_before_connecting(self): c.register_user_type("udt_test_register_before_connecting", "user", User1) c.register_user_type("udt_test_register_before_connecting2", "user", User2) - s = c.connect() + s = c.connect(wait_for_all_pools=True) s.set_keyspace("udt_test_register_before_connecting") s.execute("INSERT INTO mytable (a, b) VALUES (%s, %s)", (0, User1(42, 'bob'))) @@ -157,8 +179,8 @@ def test_can_insert_prepared_unregistered_udts(self): Test the insertion of prepared, unregistered UDTs """ - c = Cluster(protocol_version=PROTOCOL_VERSION) - s = c.connect(self.keyspace_name) + c = TestCluster() + s = c.connect(self.keyspace_name, wait_for_all_pools=True) s.execute("CREATE TYPE user (age int, name text)") s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen)") @@ -201,8 +223,8 @@ def test_can_insert_prepared_registered_udts(self): Test the insertion of prepared, registered UDTs """ - c = Cluster(protocol_version=PROTOCOL_VERSION) - s = c.connect(self.keyspace_name) + c = TestCluster() + s = c.connect(self.keyspace_name, wait_for_all_pools=True) s.execute("CREATE TYPE user (age int, name text)") User = namedtuple('user', ('age', 'name')) @@ -251,8 +273,8 @@ def test_can_insert_udts_with_nulls(self): Test the insertion of UDTs with null and empty string fields """ - c = Cluster(protocol_version=PROTOCOL_VERSION) - s = c.connect(self.keyspace_name) + c = TestCluster() + s = c.connect(self.keyspace_name, wait_for_all_pools=True) s.execute("CREATE TYPE user (a text, b int, c uuid, d blob)") User = namedtuple('user', ('a', 'b', 'c', 'd')) @@ -270,9 +292,9 @@ def test_can_insert_udts_with_nulls(self): self.assertEqual((None, None, None, None), s.execute(select)[0].b) # also test empty strings - s.execute(insert, [User('', None, None, six.binary_type())]) + s.execute(insert, [User('', None, None, bytes())]) results = s.execute("SELECT b FROM mytable WHERE a=0") - self.assertEqual(('', None, None, six.binary_type()), results[0].b) + self.assertEqual(('', None, None, bytes()), results[0].b) c.shutdown() @@ -281,15 +303,15 @@ def test_can_insert_udts_with_varying_lengths(self): Test for ensuring extra-lengthy udts are properly inserted """ - c = Cluster(protocol_version=PROTOCOL_VERSION) - s = c.connect(self.keyspace_name) + c = TestCluster() + s = c.connect(self.keyspace_name, wait_for_all_pools=True) - MAX_TEST_LENGTH = 254 + max_test_length = 254 # create the seed udt, increase timeout to avoid the query failure on slow systems s.execute("CREATE TYPE lengthy_udt ({0})" .format(', '.join(['v_{0} int'.format(i) - for i in range(MAX_TEST_LENGTH)]))) + for i in range(max_test_length)]))) # create a table with multiple sizes of nested udts # no need for all nested types, only a spot checked few and the largest one @@ -298,13 +320,13 @@ def test_can_insert_udts_with_varying_lengths(self): "v frozen)") # create and register the seed udt type - udt = namedtuple('lengthy_udt', tuple(['v_{0}'.format(i) for i in range(MAX_TEST_LENGTH)])) + udt = namedtuple('lengthy_udt', tuple(['v_{0}'.format(i) for i in range(max_test_length)])) c.register_user_type(self.keyspace_name, "lengthy_udt", udt) # verify inserts and reads - for i in (0, 1, 2, 3, MAX_TEST_LENGTH): + for i in (0, 1, 2, 3, max_test_length): # create udt - params = [j for j in range(i)] + [None for j in range(MAX_TEST_LENGTH - i)] + params = [j for j in range(i)] + [None for j in range(max_test_length - i)] created_udt = udt(*params) # write udt @@ -316,12 +338,12 @@ def 
test_can_insert_udts_with_varying_lengths(self): c.shutdown() - def nested_udt_schema_helper(self, session, MAX_NESTING_DEPTH): + def nested_udt_schema_helper(self, session, max_nesting_depth): # create the seed udt execute_until_pass(session, "CREATE TYPE depth_0 (age int, name text)") # create the nested udts - for i in range(MAX_NESTING_DEPTH): + for i in range(max_nesting_depth): execute_until_pass(session, "CREATE TYPE depth_{0} (value frozen)".format(i + 1, i)) # create a table with multiple sizes of nested udts @@ -332,7 +354,7 @@ def nested_udt_schema_helper(self, session, MAX_NESTING_DEPTH): "v_1 frozen, " "v_2 frozen, " "v_3 frozen, " - "v_{0} frozen)".format(MAX_NESTING_DEPTH)) + "v_{0} frozen)".format(max_nesting_depth)) def nested_udt_creation_helper(self, udts, i): if i == 0: @@ -340,8 +362,8 @@ def nested_udt_creation_helper(self, udts, i): else: return udts[i](self.nested_udt_creation_helper(udts, i - 1)) - def nested_udt_verification_helper(self, session, MAX_NESTING_DEPTH, udts): - for i in (0, 1, 2, 3, MAX_NESTING_DEPTH): + def nested_udt_verification_helper(self, session, max_nesting_depth, udts): + for i in (0, 1, 2, 3, max_nesting_depth): # create udt udt = self.nested_udt_creation_helper(udts, i) @@ -360,75 +382,73 @@ def nested_udt_verification_helper(self, session, MAX_NESTING_DEPTH, udts): result = session.execute("SELECT v_{0} FROM mytable WHERE k=1".format(i))[0] self.assertEqual(udt, result["v_{0}".format(i)]) + def _cluster_default_dict_factory(self): + return TestCluster( + execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=dict_factory)} + ) + def test_can_insert_nested_registered_udts(self): """ Test for ensuring nested registered udts are properly inserted """ + with self._cluster_default_dict_factory() as c: + s = c.connect(self.keyspace_name, wait_for_all_pools=True) - c = Cluster(protocol_version=PROTOCOL_VERSION) - s = c.connect(self.keyspace_name) - s.row_factory = dict_factory + max_nesting_depth = 16 - MAX_NESTING_DEPTH = 16 + # create the schema + self.nested_udt_schema_helper(s, max_nesting_depth) - # create the schema - self.nested_udt_schema_helper(s, MAX_NESTING_DEPTH) - - # create and register the seed udt type - udts = [] - udt = namedtuple('depth_0', ('age', 'name')) - udts.append(udt) - c.register_user_type(self.keyspace_name, "depth_0", udts[0]) - - # create and register the nested udt types - for i in range(MAX_NESTING_DEPTH): - udt = namedtuple('depth_{0}'.format(i + 1), ('value')) + # create and register the seed udt type + udts = [] + udt = namedtuple('depth_0', ('age', 'name')) udts.append(udt) - c.register_user_type(self.keyspace_name, "depth_{0}".format(i + 1), udts[i + 1]) + c.register_user_type(self.keyspace_name, "depth_0", udts[0]) - # insert udts and verify inserts with reads - self.nested_udt_verification_helper(s, MAX_NESTING_DEPTH, udts) + # create and register the nested udt types + for i in range(max_nesting_depth): + udt = namedtuple('depth_{0}'.format(i + 1), ('value')) + udts.append(udt) + c.register_user_type(self.keyspace_name, "depth_{0}".format(i + 1), udts[i + 1]) - c.shutdown() + # insert udts and verify inserts with reads + self.nested_udt_verification_helper(s, max_nesting_depth, udts) def test_can_insert_nested_unregistered_udts(self): """ Test for ensuring nested unregistered udts are properly inserted """ - c = Cluster(protocol_version=PROTOCOL_VERSION) - s = c.connect(self.keyspace_name) - s.row_factory = dict_factory + with self._cluster_default_dict_factory() as c: + s = 
c.connect(self.keyspace_name, wait_for_all_pools=True) - MAX_NESTING_DEPTH = 16 + max_nesting_depth = 16 - # create the schema - self.nested_udt_schema_helper(s, MAX_NESTING_DEPTH) + # create the schema + self.nested_udt_schema_helper(s, max_nesting_depth) - # create the seed udt type - udts = [] - udt = namedtuple('depth_0', ('age', 'name')) - udts.append(udt) - - # create the nested udt types - for i in range(MAX_NESTING_DEPTH): - udt = namedtuple('depth_{0}'.format(i + 1), ('value')) + # create the seed udt type + udts = [] + udt = namedtuple('depth_0', ('age', 'name')) udts.append(udt) - # insert udts via prepared statements and verify inserts with reads - for i in (0, 1, 2, 3, MAX_NESTING_DEPTH): - # create udt - udt = self.nested_udt_creation_helper(udts, i) + # create the nested udt types + for i in range(max_nesting_depth): + udt = namedtuple('depth_{0}'.format(i + 1), ('value')) + udts.append(udt) - # write udt - insert = s.prepare("INSERT INTO mytable (k, v_{0}) VALUES (0, ?)".format(i)) - s.execute(insert, [udt]) + # insert udts via prepared statements and verify inserts with reads + for i in (0, 1, 2, 3, max_nesting_depth): + # create udt + udt = self.nested_udt_creation_helper(udts, i) - # verify udt was written and read correctly - result = s.execute("SELECT v_{0} FROM mytable WHERE k=0".format(i))[0] - self.assertEqual(udt, result["v_{0}".format(i)]) + # write udt + insert = s.prepare("INSERT INTO mytable (k, v_{0}) VALUES (0, ?)".format(i)) + s.execute(insert, [udt]) - c.shutdown() + # verify udt was written and read correctly + result = s.execute("SELECT v_{0} FROM mytable WHERE k=0".format(i))[0] + self.assertEqual(udt, result["v_{0}".format(i)]) def test_can_insert_nested_registered_udts_with_different_namedtuples(self): """ @@ -436,39 +456,36 @@ def test_can_insert_nested_registered_udts_with_different_namedtuples(self): created namedtuples are use names that are different the cql type. 
""" - c = Cluster(protocol_version=PROTOCOL_VERSION) - s = c.connect(self.keyspace_name) - s.row_factory = dict_factory + with self._cluster_default_dict_factory() as c: + s = c.connect(self.keyspace_name, wait_for_all_pools=True) - MAX_NESTING_DEPTH = 16 + max_nesting_depth = 16 - # create the schema - self.nested_udt_schema_helper(s, MAX_NESTING_DEPTH) + # create the schema + self.nested_udt_schema_helper(s, max_nesting_depth) - # create and register the seed udt type - udts = [] - udt = namedtuple('level_0', ('age', 'name')) - udts.append(udt) - c.register_user_type(self.keyspace_name, "depth_0", udts[0]) - - # create and register the nested udt types - for i in range(MAX_NESTING_DEPTH): - udt = namedtuple('level_{0}'.format(i + 1), ('value')) + # create and register the seed udt type + udts = [] + udt = namedtuple('level_0', ('age', 'name')) udts.append(udt) - c.register_user_type(self.keyspace_name, "depth_{0}".format(i + 1), udts[i + 1]) + c.register_user_type(self.keyspace_name, "depth_0", udts[0]) - # insert udts and verify inserts with reads - self.nested_udt_verification_helper(s, MAX_NESTING_DEPTH, udts) + # create and register the nested udt types + for i in range(max_nesting_depth): + udt = namedtuple('level_{0}'.format(i + 1), ('value')) + udts.append(udt) + c.register_user_type(self.keyspace_name, "depth_{0}".format(i + 1), udts[i + 1]) - c.shutdown() + # insert udts and verify inserts with reads + self.nested_udt_verification_helper(s, max_nesting_depth, udts) def test_raise_error_on_nonexisting_udts(self): """ Test for ensuring that an error is raised for operating on a nonexisting udt or an invalid keyspace """ - c = Cluster(protocol_version=PROTOCOL_VERSION) - s = c.connect(self.keyspace_name) + c = TestCluster() + s = c.connect(self.keyspace_name, wait_for_all_pools=True) User = namedtuple('user', ('age', 'name')) with self.assertRaises(UserTypeDoesNotExist): @@ -487,8 +504,8 @@ def test_can_insert_udt_all_datatypes(self): Test for inserting various types of PRIMITIVE_DATATYPES into UDT's """ - c = Cluster(protocol_version=PROTOCOL_VERSION) - s = c.connect(self.keyspace_name) + c = TestCluster() + s = c.connect(self.keyspace_name, wait_for_all_pools=True) # create UDT alpha_type_list = [] @@ -532,14 +549,14 @@ def test_can_insert_udt_all_collection_datatypes(self): Test for inserting various types of COLLECTION_TYPES into UDT's """ - c = Cluster(protocol_version=PROTOCOL_VERSION) - s = c.connect(self.keyspace_name) + c = TestCluster() + s = c.connect(self.keyspace_name, wait_for_all_pools=True) # create UDT alpha_type_list = [] start_index = ord('a') for i, collection_type in enumerate(COLLECTION_TYPES): - for j, datatype in enumerate(PRIMITIVE_DATATYPES): + for j, datatype in enumerate(PRIMITIVE_DATATYPES_KEYS): if collection_type == "map": type_string = "{0}_{1} {2}<{3}, {3}>".format(chr(start_index + i), chr(start_index + j), collection_type, datatype) @@ -561,7 +578,7 @@ def test_can_insert_udt_all_collection_datatypes(self): # register UDT alphabet_list = [] for i in range(ord('a'), ord('a') + len(COLLECTION_TYPES)): - for j in range(ord('a'), ord('a') + len(PRIMITIVE_DATATYPES)): + for j in range(ord('a'), ord('a') + len(PRIMITIVE_DATATYPES_KEYS)): alphabet_list.append('{0}_{1}'.format(chr(i), chr(j))) Alldatatypes = namedtuple("alldatatypes", alphabet_list) @@ -570,7 +587,7 @@ def test_can_insert_udt_all_collection_datatypes(self): # insert UDT data params = [] for collection_type in COLLECTION_TYPES: - for datatype in PRIMITIVE_DATATYPES: + for datatype in 
PRIMITIVE_DATATYPES_KEYS: params.append((get_collection_sample(collection_type, datatype))) insert = s.prepare("INSERT INTO mytable (a, b) VALUES (?, ?)") @@ -599,8 +616,8 @@ def test_can_insert_nested_collections(self): if self.cass_version < (2, 1, 3): raise unittest.SkipTest("Support for nested collections was introduced in Cassandra 2.1.3") - c = Cluster(protocol_version=PROTOCOL_VERSION) - s = c.connect(self.keyspace_name) + c = TestCluster() + s = c.connect(self.keyspace_name, wait_for_all_pools=True) s.encoder.mapping[tuple] = s.encoder.cql_encode_tuple name = self._testMethodName @@ -668,7 +685,11 @@ def test_non_alphanum_identifiers(self): self.assertEqual(k[0], 'alphanum') self.assertEqual(k.field_0_, 'alphanum') # named tuple with positional field name + @lessthancass30 def test_type_alteration(self): + """ + Support for ALTER TYPE was removed in CASSANDRA-12443 + """ s = self.session type_name = "type_name" self.assertNotIn(type_name, s.cluster.metadata.keyspaces['udttests'].user_types) @@ -676,7 +697,7 @@ def test_type_alteration(self): self.assertIn(type_name, s.cluster.metadata.keyspaces['udttests'].user_types) s.execute('CREATE TABLE %s (k int PRIMARY KEY, v frozen<%s>)' % (self.table_name, type_name)) - s.execute('INSERT INTO %s (k, v) VALUES (0, 1)' % (self.table_name,)) + s.execute('INSERT INTO %s (k, v) VALUES (0, {v0 : 1})' % (self.table_name,)) s.cluster.register_user_type('udttests', type_name, dict) @@ -688,18 +709,19 @@ def test_type_alteration(self): val = s.execute('SELECT v FROM %s' % self.table_name)[0][0] self.assertEqual(val['v0'], 1) self.assertIsNone(val['v1']) - s.execute("INSERT INTO %s (k, v) VALUES (0, (2, 'sometext'))" % (self.table_name,)) + s.execute("INSERT INTO %s (k, v) VALUES (0, {v0 : 2, v1 : 'sometext'})" % (self.table_name,)) val = s.execute('SELECT v FROM %s' % self.table_name)[0][0] self.assertEqual(val['v0'], 2) self.assertEqual(val['v1'], 'sometext') # alter field type s.execute('ALTER TYPE %s ALTER v1 TYPE blob' % (type_name,)) - s.execute("INSERT INTO %s (k, v) VALUES (0, (3, 0xdeadbeef))" % (self.table_name,)) + s.execute("INSERT INTO %s (k, v) VALUES (0, {v0 : 3, v1 : 0xdeadbeef})" % (self.table_name,)) val = s.execute('SELECT v FROM %s' % self.table_name)[0][0] self.assertEqual(val['v0'], 3) - self.assertEqual(val['v1'], six.b('\xde\xad\xbe\xef')) + self.assertEqual(val['v1'], b'\xde\xad\xbe\xef') + @lessthancass30 def test_alter_udt(self): """ Test to ensure that altered UDT's are properly surfaced without needing to restart the underlying session. 
@@ -731,4 +753,3 @@ def test_alter_udt(self): for result in results: self.assertTrue(hasattr(result.typetoalter, 'a')) self.assertTrue(hasattr(result.typetoalter, 'b')) - diff --git a/tests/integration/standard/utils.py b/tests/integration/standard/utils.py index 8749c5e3a2..917b3a7f6e 100644 --- a/tests/integration/standard/utils.py +++ b/tests/integration/standard/utils.py @@ -4,6 +4,7 @@ from tests.integration.datatype_utils import PRIMITIVE_DATATYPES, get_sample + def create_table_with_all_types(table_name, session, N): """ Method that given a table_name and session construct a table that contains @@ -45,7 +46,11 @@ def get_all_primitive_params(key): """ params = [key] for datatype in PRIMITIVE_DATATYPES: - params.append(get_sample(datatype)) + # Also test for empty strings + if key == 1 and datatype == 'ascii': + params.append('') + else: + params.append(get_sample(datatype)) return params diff --git a/tests/integration/upgrade/__init__.py b/tests/integration/upgrade/__init__.py new file mode 100644 index 0000000000..5dfb4fecf8 --- /dev/null +++ b/tests/integration/upgrade/__init__.py @@ -0,0 +1,188 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from tests.integration import CCM_KWARGS, use_cluster, remove_cluster, MockLoggingHandler +from tests.integration import setup_keyspace + +from cassandra.cluster import Cluster +from cassandra import cluster + +from collections import namedtuple +from functools import wraps +import logging +from threading import Thread, Event +from ccmlib.node import TimeoutError +import time +import logging + +import unittest + + +def setup_module(): + remove_cluster() + + +UPGRADE_CLUSTER_NAME = "upgrade_cluster" +UpgradePath = namedtuple('UpgradePath', ('name', 'starting_version', 'upgrade_version', 'configuration_options')) + +log = logging.getLogger(__name__) + + +class upgrade_paths(object): + """ + Decorator used to specify the upgrade paths for a particular method + """ + def __init__(self, paths): + self.paths = paths + + def __call__(self, method): + @wraps(method) + def wrapper(*args, **kwargs): + for path in self.paths: + self_from_decorated = args[0] + log.debug('setting up {path}'.format(path=path)) + self_from_decorated.UPGRADE_PATH = path + self_from_decorated._upgrade_step_setup() + method(*args, **kwargs) + log.debug('tearing down {path}'.format(path=path)) + self_from_decorated._upgrade_step_teardown() + return wrapper + + +class UpgradeBase(unittest.TestCase): + """ + Base class for the upgrade tests. The _setup method + will clean the environment and start the appropriate C* version according + to the upgrade path. 
The upgrade can be done in a different thread using the + start_upgrade upgrade_method (this would be the most realistic scenario) + or node by node, waiting for the upgrade to happen, using _upgrade_one_node method + """ + UPGRADE_PATH = None + start_cluster = True + set_keyspace = True + + @classmethod + def setUpClass(cls): + cls.logger_handler = MockLoggingHandler() + logger = logging.getLogger(cluster.__name__) + logger.addHandler(cls.logger_handler) + + def _upgrade_step_setup(self): + """ + This is not the regular _setUp method because it will be called from + the decorator instead of letting nose handle it. + This setup method will start a cluster with the right version according + to the variable UPGRADE_PATH. + """ + remove_cluster() + self.cluster = use_cluster(UPGRADE_CLUSTER_NAME + self.UPGRADE_PATH.name, [3], + ccm_options=self.UPGRADE_PATH.starting_version, set_keyspace=self.set_keyspace, + configuration_options=self.UPGRADE_PATH.configuration_options) + self.nodes = self.cluster.nodelist() + self.last_node_upgraded = None + self.upgrade_done = Event() + self.upgrade_thread = None + + if self.start_cluster: + setup_keyspace() + + self.cluster_driver = Cluster() + self.session = self.cluster_driver.connect() + self.logger_handler.reset() + + def _upgrade_step_teardown(self): + """ + special tearDown method called by the decorator after the method has ended + """ + if self.upgrade_thread: + self.upgrade_thread.join(timeout=5) + self.upgrade_thread = None + + if self.start_cluster: + self.cluster_driver.shutdown() + + def start_upgrade(self, time_node_upgrade): + """ + Starts the upgrade in a different thread + """ + log.debug('Starting upgrade in new thread') + self.upgrade_thread = Thread(target=self._upgrade, args=(time_node_upgrade,)) + self.upgrade_thread.start() + + def _upgrade(self, time_node_upgrade): + """ + Starts the upgrade in the same thread + """ + start_time = time.time() + for node in self.nodes: + self.upgrade_node(node) + end_time = time.time() + time_to_upgrade = end_time - start_time + if time_node_upgrade > time_to_upgrade: + time.sleep(time_node_upgrade - time_to_upgrade) + self.upgrade_done.set() + + def is_upgraded(self): + """ + Returns True if the upgrade has finished and False otherwise + """ + return self.upgrade_done.is_set() + + def wait_for_upgrade(self, timeout=None): + """ + Waits until the upgrade has completed + """ + self.upgrade_done.wait(timeout=timeout) + + def upgrade_node(self, node): + """ + Upgrades only one node. Return True if the upgrade + has finished and False otherwise + """ + node.drain() + node.stop(gently=True) + + node.set_install_dir(**self.UPGRADE_PATH.upgrade_version) + + # There must be a cleaner way of doing this, but it's necessary here + # to call the private method from cluster __update_topology_files + self.cluster._Cluster__update_topology_files() + try: + node.start(wait_for_binary_proto=True, wait_other_notice=True) + except TimeoutError: + self.fail("Error starting C* node while upgrading") + + return True + + +class UpgradeBaseAuth(UpgradeBase): + """ + Base class of authentication test, the authentication parameters for + C* still have to be specified within the upgrade path variable + """ + start_cluster = False + set_keyspace = False + + + def _upgrade_step_setup(self): + """ + We sleep here for the same reason as we do in test_authentication.py: + there seems to be some race, with some versions of C* taking longer to + get the auth (and default user) setup. 
Sleep here to give it a chance + """ + super(UpgradeBaseAuth, self)._upgrade_step_setup() + time.sleep(10) diff --git a/tests/integration/upgrade/test_upgrade.py b/tests/integration/upgrade/test_upgrade.py new file mode 100644 index 0000000000..837d8232cb --- /dev/null +++ b/tests/integration/upgrade/test_upgrade.py @@ -0,0 +1,287 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from itertools import count + +from cassandra.auth import PlainTextAuthProvider, SaslAuthProvider +from cassandra.cluster import ConsistencyLevel, Cluster, DriverException, ExecutionProfile +from cassandra.policies import ConstantSpeculativeExecutionPolicy +from tests.integration.upgrade import UpgradeBase, UpgradeBaseAuth, UpgradePath, upgrade_paths + +import unittest + + +# Previous Cassandra upgrade +two_to_three_path = upgrade_paths([ + UpgradePath("2.2.9-3.11", {"version": "2.2.9"}, {"version": "3.11.4"}, {}), +]) + +# Previous DSE upgrade +five_upgrade_path = upgrade_paths([ + UpgradePath("5.0.11-5.1.4", {"version": "5.0.11"}, {"version": "5.1.4"}, {}), +]) + + +class UpgradeTests(UpgradeBase): + @two_to_three_path + def test_can_write(self): + """ + Verify that the driver will keep querying C* even if there is a host down while being + upgraded and that all the writes will eventually succeed + @since 3.12 + @jira_ticket PYTHON-546 + @expected_result all the writes succeed + + @test_category upgrade + """ + self.start_upgrade(0) + + self.cluster_driver.add_execution_profile("all", ExecutionProfile(consistency_level=ConsistencyLevel.ALL)) + self.cluster_driver.add_execution_profile("one", ExecutionProfile(consistency_level=ConsistencyLevel.LOCAL_ONE)) + + c = count() + while not self.is_upgraded(): + self.session.execute("INSERT INTO test3rf.test(k, v) VALUES (%s, 0)", (next(c), ), execution_profile="one") + time.sleep(0.0001) + + total_number_of_inserted = self.session.execute("SELECT COUNT(*) from test3rf.test", execution_profile="all")[0][0] + self.assertEqual(total_number_of_inserted, next(c)) + + self.assertEqual(self.logger_handler.get_message_count("error", ""), 0) + + @two_to_three_path + def test_can_connect(self): + """ + Verify that the driver can connect to all the nodes + despite some nodes being in different versions + @since 3.12 + @jira_ticket PYTHON-546 + @expected_result the driver connects successfully and can execute queries against + all the hosts + + @test_category upgrade + """ + def connect_and_shutdown(): + cluster = Cluster() + session = cluster.connect(wait_for_all_pools=True) + queried_hosts = set() + for _ in range(10): + results = session.execute("SELECT * from system.local") + self.assertGreater(len(results.current_rows), 0) + self.assertEqual(len(results.response_future.attempted_hosts), 1) + 
queried_hosts.add(results.response_future.attempted_hosts[0]) + self.assertEqual(len(queried_hosts), 3) + cluster.shutdown() + + connect_and_shutdown() + for node in self.nodes: + self.upgrade_node(node) + connect_and_shutdown() + + connect_and_shutdown() + + +class UpgradeTestsMetadata(UpgradeBase): + @two_to_three_path + def test_can_write(self): + """ + Verify that the driver will keep querying C* even if there is a host down while being + upgraded and that all the writes will eventually succeed + @since 3.12 + @jira_ticket PYTHON-546 + @expected_result all the writes succeed + + @test_category upgrade + """ + self.start_upgrade(0) + + self.cluster_driver.add_execution_profile("all", ExecutionProfile(consistency_level=ConsistencyLevel.ALL)) + self.cluster_driver.add_execution_profile("one", ExecutionProfile(consistency_level=ConsistencyLevel.LOCAL_ONE)) + + c = count() + while not self.is_upgraded(): + self.session.execute("INSERT INTO test3rf.test(k, v) VALUES (%s, 0)", (next(c),), execution_profile="one") + time.sleep(0.0001) + + total_number_of_inserted = self.session.execute("SELECT COUNT(*) from test3rf.test", execution_profile="all")[0][0] + self.assertEqual(total_number_of_inserted, next(c)) + + self.assertEqual(self.logger_handler.get_message_count("error", ""), 0) + + @two_to_three_path + def test_schema_metadata_gets_refreshed(self): + """ + Verify that the driver fails to update the metadata while connected against + different versions of nodes. This won't succeed because each node will report a + different schema version + + @since 3.12 + @jira_ticket PYTHON-546 + @expected_result the driver raises DriverException when updating the schema + metadata while upgrading + all the hosts + + @test_category metadata + """ + original_meta = self.cluster_driver.metadata.keyspaces + number_of_nodes = len(self.cluster.nodelist()) + nodes = self.nodes + for node in nodes[1:]: + self.upgrade_node(node) + # Wait for the control connection to reconnect + time.sleep(20) + + with self.assertRaises(DriverException): + self.cluster_driver.refresh_schema_metadata(max_schema_agreement_wait=10) + + self.upgrade_node(nodes[0]) + # Wait for the control connection to reconnect + time.sleep(20) + self.cluster_driver.refresh_schema_metadata(max_schema_agreement_wait=40) + self.assertNotEqual(original_meta, self.cluster_driver.metadata.keyspaces) + + @two_to_three_path + def test_schema_nodes_gets_refreshed(self): + """ + Verify that the driver token map and node list gets rebuild correctly while upgrading. 
+ The token map and the node list should be the same after each node upgrade + + @since 3.12 + @jira_ticket PYTHON-546 + @expected_result the token map and the node list stays consistent with each node upgrade + metadata while upgrading + all the hosts + + @test_category metadata + """ + for node in self.nodes: + token_map = self.cluster_driver.metadata.token_map + self.upgrade_node(node) + # Wait for the control connection to reconnect + time.sleep(20) + + self.cluster_driver.refresh_nodes(force_token_rebuild=True) + self._assert_same_token_map(token_map, self.cluster_driver.metadata.token_map) + + def _assert_same_token_map(self, original, new): + self.assertIsNot(original, new) + self.assertEqual(original.tokens_to_hosts_by_ks, new.tokens_to_hosts_by_ks) + self.assertEqual(original.token_to_host_owner, new.token_to_host_owner) + self.assertEqual(original.ring, new.ring) + + +two_to_three_with_auth_path = upgrade_paths([ + UpgradePath("2.2.9-3.11-auth", {"version": "2.2.9"}, {"version": "3.11.4"}, + {'authenticator': 'PasswordAuthenticator', + 'authorizer': 'CassandraAuthorizer'}), +]) +class UpgradeTestsAuthentication(UpgradeBaseAuth): + @two_to_three_with_auth_path + def test_can_connect_auth_plain(self): + """ + Verify that the driver can connect despite some nodes being in different versions + with plain authentication + @since 3.12 + @jira_ticket PYTHON-546 + @expected_result the driver connects successfully and can execute queries against + all the hosts + + @test_category upgrade + """ + auth_provider = PlainTextAuthProvider( + username="cassandra", + password="cassandra" + ) + self.connect_and_shutdown(auth_provider) + for node in self.nodes: + self.upgrade_node(node) + self.connect_and_shutdown(auth_provider) + + self.connect_and_shutdown(auth_provider) + + @two_to_three_with_auth_path + def test_can_connect_auth_sasl(self): + """ + Verify that the driver can connect despite some nodes being in different versions + with ssl authentication + @since 3.12 + @jira_ticket PYTHON-546 + @expected_result the driver connects successfully and can execute queries against + all the hosts + + @test_category upgrade + """ + sasl_kwargs = {'service': 'cassandra', + 'mechanism': 'PLAIN', + 'qops': ['auth'], + 'username': 'cassandra', + 'password': 'cassandra'} + auth_provider = SaslAuthProvider(**sasl_kwargs) + self.connect_and_shutdown(auth_provider) + for node in self.nodes: + self.upgrade_node(node) + self.connect_and_shutdown(auth_provider) + + self.connect_and_shutdown(auth_provider) + + def connect_and_shutdown(self, auth_provider): + cluster = Cluster(idle_heartbeat_interval=0, + auth_provider=auth_provider) + session = cluster.connect(wait_for_all_pools=True) + queried_hosts = set() + for _ in range(10): + results = session.execute("SELECT * from system.local") + self.assertGreater(len(results.current_rows), 0) + self.assertEqual(len(results.response_future.attempted_hosts), 1) + queried_hosts.add(results.response_future.attempted_hosts[0]) + self.assertEqual(len(queried_hosts), 3) + cluster.shutdown() + + +class UpgradeTestsPolicies(UpgradeBase): + @two_to_three_path + def test_can_write_speculative(self): + """ + Verify that the driver will keep querying C* even if there is a host down while being + upgraded and that all the writes will eventually succeed using the ConstantSpeculativeExecutionPolicy + policy + @since 3.12 + @jira_ticket PYTHON-546 + @expected_result all the writes succeed + + @test_category upgrade + """ + spec_ep_rr = 
ExecutionProfile(speculative_execution_policy=ConstantSpeculativeExecutionPolicy(.5, 10), + request_timeout=12) + cluster = Cluster() + self.addCleanup(cluster.shutdown) + cluster.add_execution_profile("spec_ep_rr", spec_ep_rr) + cluster.add_execution_profile("all", ExecutionProfile(consistency_level=ConsistencyLevel.ALL)) + session = cluster.connect() + + self.start_upgrade(0) + + c = count() + while not self.is_upgraded(): + session.execute("INSERT INTO test3rf.test(k, v) VALUES (%s, 0)", (next(c),), + execution_profile='spec_ep_rr') + time.sleep(0.0001) + + total_number_of_inserted = session.execute("SELECT COUNT(*) from test3rf.test", execution_profile="all")[0][0] + self.assertEqual(total_number_of_inserted, next(c)) + + self.assertEqual(self.logger_handler.get_message_count("error", ""), 0) diff --git a/tests/integration/util.py b/tests/integration/util.py index ff71185a4a..efdf258b2b 100644 --- a/tests/integration/util.py +++ b/tests/integration/util.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,9 +14,26 @@ # See the License for the specific language governing permissions and # limitations under the License. +from itertools import chain + from tests.integration import PROTOCOL_VERSION +import time + -def assert_quiescent_pool_state(test_case, cluster): +def assert_quiescent_pool_state(test_case, cluster, wait=None): + """ + Checking the quiescent pool state checks that none of the requests ids have + been lost. However, the callback corresponding to a request_id is called + before the request_id is returned back to the pool, therefore + + session.execute("SELECT * from system.local") + assert_quiescent_pool_state(self, session.cluster) + + (with no wait) might fail because when execute comes back the request_id + hasn't yet been returned to the pool, therefore the wait. 
+ """ + if wait is not None: + time.sleep(wait) for session in cluster.sessions: pool_states = session.get_pool_state().values() @@ -23,15 +42,18 @@ def assert_quiescent_pool_state(test_case, cluster): for state in pool_states: test_case.assertFalse(state['shutdown']) test_case.assertGreater(state['open_count'], 0) - test_case.assertTrue(all((i == 0 for i in state['in_flights']))) + no_in_flight = all((i == 0 for i in state['in_flights'])) + orphans_and_inflights = zip(state['orphan_requests'], state['in_flights']) + all_orphaned = all((len(orphans) == inflight for (orphans, inflight) in orphans_and_inflights)) + test_case.assertTrue(no_in_flight or all_orphaned) for holder in cluster.get_connection_holders(): for connection in holder.get_connections(): # all ids are unique req_ids = connection.request_ids + orphan_ids = connection.orphaned_request_ids test_case.assertEqual(len(req_ids), len(set(req_ids))) - test_case.assertEqual(connection.highest_request_id, len(req_ids) - 1) - test_case.assertEqual(connection.highest_request_id, max(req_ids)) + test_case.assertEqual(connection.highest_request_id, len(req_ids) + len(orphan_ids) - 1) + test_case.assertEqual(connection.highest_request_id, max(chain(req_ids, orphan_ids))) if PROTOCOL_VERSION < 3: test_case.assertEqual(connection.highest_request_id, connection.max_request_id) - diff --git a/tests/stress_tests/test_load.py b/tests/stress_tests/test_load.py index 523ee2004b..30c384f098 100644 --- a/tests/stress_tests/test_load.py +++ b/tests/stress_tests/test_load.py @@ -1,20 +1,19 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest import gc diff --git a/tests/stress_tests/test_multi_inserts.py b/tests/stress_tests/test_multi_inserts.py index e39e73e8b7..3e32e233f1 100644 --- a/tests/stress_tests/test_multi_inserts.py +++ b/tests/stress_tests/test_multi_inserts.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,10 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -try: - import unittest2 as unittest -except ImportError: - import unittest +import unittest import os from cassandra.cluster import Cluster diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py index 87fc3685e0..588a655d98 100644 --- a/tests/unit/__init__.py +++ b/tests/unit/__init__.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/tests/unit/advanced/__init__.py b/tests/unit/advanced/__init__.py new file mode 100644 index 0000000000..635f0d9e60 --- /dev/null +++ b/tests/unit/advanced/__init__.py @@ -0,0 +1,15 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/advanced/cloud/__init__.py b/tests/unit/advanced/cloud/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/advanced/cloud/creds.zip b/tests/unit/advanced/cloud/creds.zip new file mode 100644 index 0000000000..6bd8faf69f Binary files /dev/null and b/tests/unit/advanced/cloud/creds.zip differ diff --git a/tests/unit/advanced/cloud/test_cloud.py b/tests/unit/advanced/cloud/test_cloud.py new file mode 100644 index 0000000000..04cbf883f3 --- /dev/null +++ b/tests/unit/advanced/cloud/test_cloud.py @@ -0,0 +1,106 @@ +# Copyright DataStax, Inc. +# +# Licensed under the DataStax DSE Driver License; +# you may not use this file except in compliance with the License. 
+# +# You may obtain a copy of the License at +# +# http://www.datastax.com/terms/datastax-dse-driver-license-terms +import tempfile +import os +import shutil +import unittest +from unittest.mock import patch + +from cassandra import DriverException +from cassandra.datastax import cloud + +from tests import notwindows + +class CloudTests(unittest.TestCase): + + current_path = os.path.dirname(os.path.abspath(__file__)) + creds_path = os.path.join(current_path, './creds.zip') + config_zip = { + 'secure_connect_bundle': creds_path + } + metadata_json = """ + {"region":"local", + "contact_info": { + "type":"sni_proxy", + "local_dc":"dc1", + "contact_points":[ + "b13ae7b4-e711-4660-8dd1-bec57d37aa64", + "d4330144-5fb3-425a-86a1-431b3e4d0671", + "86537b87-91a9-4c59-b715-716486e72c42" + ], + "sni_proxy_address":"localhost:30002" + } + }""" + + @staticmethod + def _read_metadata_info_side_effect(config, _): + return config + + def _check_config(self, config): + self.assertEqual(config.username, 'cassandra') + self.assertEqual(config.password, 'cassandra') + self.assertEqual(config.host, 'localhost') + self.assertEqual(config.port, 30443) + self.assertEqual(config.keyspace, 'system') + self.assertEqual(config.local_dc, None) + self.assertIsNotNone(config.ssl_context) + self.assertIsNone(config.sni_host) + self.assertIsNone(config.sni_port) + self.assertIsNone(config.host_ids) + + def test_read_cloud_config_from_zip(self): + + with patch('cassandra.datastax.cloud.read_metadata_info', side_effect=self._read_metadata_info_side_effect): + config = cloud.get_cloud_config(self.config_zip) + + self._check_config(config) + + def test_parse_metadata_info(self): + config = cloud.CloudConfig() + cloud.parse_metadata_info(config, self.metadata_json) + self.assertEqual(config.sni_host, 'localhost') + self.assertEqual(config.sni_port, 30002) + self.assertEqual(config.local_dc, 'dc1') + + host_ids = [ + "b13ae7b4-e711-4660-8dd1-bec57d37aa64", + "d4330144-5fb3-425a-86a1-431b3e4d0671", + "86537b87-91a9-4c59-b715-716486e72c42" + ] + for host_id in host_ids: + self.assertIn(host_id, config.host_ids) + + @notwindows + def test_use_default_tempdir(self): + tmpdir = tempfile.mkdtemp() + + def clean_tmp_dir(): + os.chmod(tmpdir, 0o777) + shutil.rmtree(tmpdir) + self.addCleanup(clean_tmp_dir) + + tmp_creds_path = os.path.join(tmpdir, 'creds.zip') + shutil.copyfile(self.creds_path, tmp_creds_path) + os.chmod(tmpdir, 0o544) + config = { + 'secure_connect_bundle': tmp_creds_path + } + + # The directory is not writable, so we expect a permission error + with self.assertRaises(PermissionError): + cloud.get_cloud_config(config) + + # With use_default_tempdir, we expect a connection error + # since the cluster doesn't exist + with self.assertRaises(DriverException): + config = { + 'secure_connect_bundle': tmp_creds_path, + 'use_default_tempdir': True + } + cloud.get_cloud_config(config) diff --git a/tests/unit/advanced/test_auth.py b/tests/unit/advanced/test_auth.py new file mode 100644 index 0000000000..6457810a6f --- /dev/null +++ b/tests/unit/advanced/test_auth.py @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from puresasl import QOP + +import unittest + +from cassandra.auth import DSEGSSAPIAuthProvider + +# Cannot import requiredse from tests.integration +# This auth provider requires kerberos and puresasl +DSE_VERSION = os.getenv('DSE_VERSION', None) +@unittest.skipUnless(DSE_VERSION, "DSE required") +class TestGSSAPI(unittest.TestCase): + + def test_host_resolution(self): + # resolves by default + provider = DSEGSSAPIAuthProvider(service='test', qops=QOP.all) + authenticator = provider.new_authenticator('127.0.0.1') + self.assertEqual(authenticator.sasl.host, 'localhost') + + # numeric fallback okay + authenticator = provider.new_authenticator('192.0.2.1') + self.assertEqual(authenticator.sasl.host, '192.0.2.1') + + # disable explicitly + provider = DSEGSSAPIAuthProvider(service='test', qops=QOP.all, resolve_host_name=False) + authenticator = provider.new_authenticator('127.0.0.1') + self.assertEqual(authenticator.sasl.host, '127.0.0.1') + diff --git a/tests/unit/advanced/test_execution_profile.py b/tests/unit/advanced/test_execution_profile.py new file mode 100644 index 0000000000..143a391f72 --- /dev/null +++ b/tests/unit/advanced/test_execution_profile.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from cassandra.cluster import GraphExecutionProfile, GraphAnalyticsExecutionProfile +from cassandra.graph import GraphOptions + + +class GraphExecutionProfileTest(unittest.TestCase): + + def test_graph_source_can_be_set_with_graph_execution_profile(self): + options = GraphOptions(graph_source='a') + ep = GraphExecutionProfile(graph_options=options) + self.assertEqual(ep.graph_options.graph_source, b'a') + + def test_graph_source_is_preserve_with_graph_analytics_execution_profile(self): + options = GraphOptions(graph_source='doesnt_matter') + ep = GraphAnalyticsExecutionProfile(graph_options=options) + self.assertEqual(ep.graph_options.graph_source, b'a') # graph source is set automatically diff --git a/tests/unit/advanced/test_geometry.py b/tests/unit/advanced/test_geometry.py new file mode 100644 index 0000000000..0e5dc8f93f --- /dev/null +++ b/tests/unit/advanced/test_geometry.py @@ -0,0 +1,280 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import struct +import math +from cassandra.cqltypes import lookup_casstype +from cassandra.protocol import ProtocolVersion +from cassandra.cqltypes import PointType, LineStringType, PolygonType, WKBGeometryType +from cassandra.util import Point, LineString, Polygon, _LinearRing, Distance, _HAS_GEOMET + +wkb_be = 0 +wkb_le = 1 + +protocol_versions = ProtocolVersion.SUPPORTED_VERSIONS + + +class GeoTypes(unittest.TestCase): + + samples = (Point(1, 2), LineString(((1, 2), (3, 4), (5, 6))), Polygon([(10.1, 10.0), (110.0, 10.0), (110., 110.0), (10., 110.0), (10., 10.0)], [[(20., 20.0), (20., 30.0), (30., 30.0), (30., 20.0), (20., 20.0)], [(40., 20.0), (40., 30.0), (50., 30.0), (50., 20.0), (40., 20.0)]])) + + def test_marshal_platform(self): + for proto_ver in protocol_versions: + for geo in self.samples: + cql_type = lookup_casstype(geo.__class__.__name__ + 'Type') + self.assertEqual(cql_type.from_binary(cql_type.to_binary(geo, proto_ver), proto_ver), geo) + + def _verify_both_endian(self, typ, body_fmt, params, expected): + for proto_ver in protocol_versions: + self.assertEqual(typ.from_binary(struct.pack(">BI" + body_fmt, wkb_be, *params), proto_ver), expected) + self.assertEqual(typ.from_binary(struct.pack(" base map + base = GraphOptions(**self.api_params) + self.assertEqual(GraphOptions().get_options_map(base), base._graph_options) + + # something set overrides + kwargs = self.api_params.copy() # this test concept got strange after we added default values for a couple GraphOption attrs + kwargs['graph_name'] = 'unit_test' + other = GraphOptions(**kwargs) + options = base.get_options_map(other) + updated = self.opt_mapping['graph_name'] + self.assertEqual(options[updated], b'unit_test') + for name in (n for n in self.opt_mapping.values() if n != updated): + self.assertEqual(options[name], base._graph_options[name]) + + # base unchanged + self._verify_api_params(base, self.api_params) + + def test_set_attr(self): + expected = 'test@@@@' + opts = GraphOptions(graph_name=expected) + self.assertEqual(opts.graph_name, expected.encode()) + expected = 'somethingelse####' + opts.graph_name = expected + self.assertEqual(opts.graph_name, expected.encode()) + + # will update options with set value + another = GraphOptions() + self.assertIsNone(another.graph_name) + another.update(opts) + self.assertEqual(another.graph_name, expected.encode()) + + opts.graph_name = None + self.assertIsNone(opts.graph_name) + # will not update another with its set-->unset value + another.update(opts) + self.assertEqual(another.graph_name, expected.encode()) # remains unset + opt_map = another.get_options_map(opts) + self.assertEqual(opt_map, another._graph_options) + + def test_del_attr(self): + opts = GraphOptions(**self.api_params) + test_params = self.api_params.copy() + del test_params['graph_source'] + del 
opts.graph_source + self._verify_api_params(opts, test_params) + + def _verify_api_params(self, opts, api_params): + self.assertEqual(len(opts._graph_options), len(api_params)) + for name, value in api_params.items(): + try: + value = value.encode() + except: + pass # already bytes + self.assertEqual(getattr(opts, name), value) + self.assertEqual(opts._graph_options[self.opt_mapping[name]], value) + + def test_consistency_levels(self): + read_cl = ConsistencyLevel.ONE + write_cl = ConsistencyLevel.LOCAL_QUORUM + + # set directly + opts = GraphOptions(graph_read_consistency_level=read_cl, graph_write_consistency_level=write_cl) + self.assertEqual(opts.graph_read_consistency_level, read_cl) + self.assertEqual(opts.graph_write_consistency_level, write_cl) + + # mapping from base + opt_map = opts.get_options_map() + self.assertEqual(opt_map['graph-read-consistency'], ConsistencyLevel.value_to_name[read_cl].encode()) + self.assertEqual(opt_map['graph-write-consistency'], ConsistencyLevel.value_to_name[write_cl].encode()) + + # empty by default + new_opts = GraphOptions() + opt_map = new_opts.get_options_map() + self.assertNotIn('graph-read-consistency', opt_map) + self.assertNotIn('graph-write-consistency', opt_map) + + # set from other + opt_map = new_opts.get_options_map(opts) + self.assertEqual(opt_map['graph-read-consistency'], ConsistencyLevel.value_to_name[read_cl].encode()) + self.assertEqual(opt_map['graph-write-consistency'], ConsistencyLevel.value_to_name[write_cl].encode()) + + def test_graph_source_convenience_attributes(self): + opts = GraphOptions() + self.assertEqual(opts.graph_source, b'g') + self.assertFalse(opts.is_analytics_source) + self.assertTrue(opts.is_graph_source) + self.assertFalse(opts.is_default_source) + + opts.set_source_default() + self.assertIsNotNone(opts.graph_source) + self.assertFalse(opts.is_analytics_source) + self.assertFalse(opts.is_graph_source) + self.assertTrue(opts.is_default_source) + + opts.set_source_analytics() + self.assertIsNotNone(opts.graph_source) + self.assertTrue(opts.is_analytics_source) + self.assertFalse(opts.is_graph_source) + self.assertFalse(opts.is_default_source) + + opts.set_source_graph() + self.assertIsNotNone(opts.graph_source) + self.assertFalse(opts.is_analytics_source) + self.assertTrue(opts.is_graph_source) + self.assertFalse(opts.is_default_source) + +class GraphStatementTests(unittest.TestCase): + + def test_init(self): + # just make sure Statement attributes are accepted + kwargs = {'query_string': object(), + 'retry_policy': RetryPolicy(), + 'consistency_level': object(), + 'fetch_size': object(), + 'keyspace': object(), + 'custom_payload': object()} + statement = SimpleGraphStatement(**kwargs) + for k, v in kwargs.items(): + self.assertIs(getattr(statement, k), v) + + # but not a bogus parameter + kwargs['bogus'] = object() + self.assertRaises(TypeError, SimpleGraphStatement, **kwargs) + + +class GraphRowFactoryTests(unittest.TestCase): + + def test_object_row_factory(self): + col_names = [] # unused + rows = [object() for _ in range(10)] + self.assertEqual(single_object_row_factory(col_names, ((o,) for o in rows)), rows) + + def test_graph_result_row_factory(self): + col_names = [] # unused + rows = [json.dumps({'result': i}) for i in range(10)] + results = graph_result_row_factory(col_names, ((o,) for o in rows)) + for i, res in enumerate(results): + self.assertIsInstance(res, Result) + self.assertEqual(res.value, i) diff --git a/tests/unit/advanced/test_insights.py b/tests/unit/advanced/test_insights.py new file 
mode 100644 index 0000000000..e6be6fc3d1 --- /dev/null +++ b/tests/unit/advanced/test_insights.py @@ -0,0 +1,307 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import logging +import sys +from unittest.mock import sentinel + +from cassandra import ConsistencyLevel +from cassandra.cluster import ( + ExecutionProfile, GraphExecutionProfile, ProfileManager, + GraphAnalyticsExecutionProfile, + EXEC_PROFILE_DEFAULT, EXEC_PROFILE_GRAPH_DEFAULT, + EXEC_PROFILE_GRAPH_ANALYTICS_DEFAULT, + EXEC_PROFILE_GRAPH_SYSTEM_DEFAULT +) +from cassandra.datastax.graph.query import GraphOptions +from cassandra.datastax.insights.registry import insights_registry +from cassandra.datastax.insights.serializers import initialize_registry +from cassandra.datastax.insights.util import namespace +from cassandra.policies import ( + RoundRobinPolicy, + LoadBalancingPolicy, + DCAwareRoundRobinPolicy, + TokenAwarePolicy, + WhiteListRoundRobinPolicy, + HostFilterPolicy, + ConstantReconnectionPolicy, + ExponentialReconnectionPolicy, + RetryPolicy, + SpeculativeExecutionPolicy, + ConstantSpeculativeExecutionPolicy, + WrapperPolicy +) + + +log = logging.getLogger(__name__) + +initialize_registry(insights_registry) + + +class TestGetConfig(unittest.TestCase): + + def test_invalid_object(self): + class NoConfAsDict(object): + pass + + obj = NoConfAsDict() + + ns = 'tests.unit.advanced.test_insights' + if sys.version_info > (3,): + ns += '.TestGetConfig.test_invalid_object.' + + # no default + # ... as a policy + self.assertEqual(insights_registry.serialize(obj, policy=True), + {'type': 'NoConfAsDict', + 'namespace': ns, + 'options': {}}) + # ... 
not as a policy (default) + self.assertEqual(insights_registry.serialize(obj), + {'type': 'NoConfAsDict', + 'namespace': ns, + }) + # with default + self.assertIs(insights_registry.serialize(obj, default=sentinel.attr_err_default), + sentinel.attr_err_default) + + def test_successful_return(self): + + class SuperclassSentinel(object): + pass + + class SubclassSentinel(SuperclassSentinel): + pass + + @insights_registry.register_serializer_for(SuperclassSentinel) + def superclass_sentinel_serializer(obj): + return sentinel.serialized_superclass + + self.assertIs(insights_registry.serialize(SuperclassSentinel()), + sentinel.serialized_superclass) + self.assertIs(insights_registry.serialize(SubclassSentinel()), + sentinel.serialized_superclass) + + # with default -- same behavior + self.assertIs(insights_registry.serialize(SubclassSentinel(), default=object()), + sentinel.serialized_superclass) + +class TestConfigAsDict(unittest.TestCase): + + # graph/query.py + def test_graph_options(self): + self.maxDiff = None + + go = GraphOptions(graph_name='name_for_test', + graph_source='source_for_test', + graph_language='lang_for_test', + graph_protocol='protocol_for_test', + graph_read_consistency_level=ConsistencyLevel.ANY, + graph_write_consistency_level=ConsistencyLevel.ONE, + graph_invalid_option='invalid') + + log.debug(go._graph_options) + + self.assertEqual( + insights_registry.serialize(go), + {'source': 'source_for_test', + 'language': 'lang_for_test', + 'graphProtocol': 'protocol_for_test', + # no graph_invalid_option + } + ) + + # cluster.py + def test_execution_profile(self): + self.maxDiff = None + self.assertEqual( + insights_registry.serialize(ExecutionProfile()), + {'consistency': 'LOCAL_ONE', + 'continuousPagingOptions': None, + 'loadBalancing': {'namespace': 'cassandra.policies', + 'options': {'child_policy': {'namespace': 'cassandra.policies', + 'options': {'local_dc': '', + 'used_hosts_per_remote_dc': 0}, + 'type': 'DCAwareRoundRobinPolicy'}, + 'shuffle_replicas': False}, + 'type': 'TokenAwarePolicy'}, + 'readTimeout': 10.0, + 'retry': {'namespace': 'cassandra.policies', 'options': {}, 'type': 'RetryPolicy'}, + 'serialConsistency': None, + 'speculativeExecution': {'namespace': 'cassandra.policies', + 'options': {}, 'type': 'NoSpeculativeExecutionPolicy'}, + 'graphOptions': None + } + ) + + def test_graph_execution_profile(self): + self.maxDiff = None + self.assertEqual( + insights_registry.serialize(GraphExecutionProfile()), + {'consistency': 'LOCAL_ONE', + 'continuousPagingOptions': None, + 'loadBalancing': {'namespace': 'cassandra.policies', + 'options': {'child_policy': {'namespace': 'cassandra.policies', + 'options': {'local_dc': '', + 'used_hosts_per_remote_dc': 0}, + 'type': 'DCAwareRoundRobinPolicy'}, + 'shuffle_replicas': False}, + 'type': 'TokenAwarePolicy'}, + 'readTimeout': 30.0, + 'retry': {'namespace': 'cassandra.policies', 'options': {}, 'type': 'NeverRetryPolicy'}, + 'serialConsistency': None, + 'speculativeExecution': {'namespace': 'cassandra.policies', + 'options': {}, 'type': 'NoSpeculativeExecutionPolicy'}, + 'graphOptions': {'graphProtocol': None, + 'language': 'gremlin-groovy', + 'source': 'g'}, + } + ) + + def test_graph_analytics_execution_profile(self): + self.maxDiff = None + self.assertEqual( + insights_registry.serialize(GraphAnalyticsExecutionProfile()), + {'consistency': 'LOCAL_ONE', + 'continuousPagingOptions': None, + 'loadBalancing': {'namespace': 'cassandra.policies', + 'options': {'child_policy': {'namespace': 'cassandra.policies', + 'options': 
{'child_policy': {'namespace': 'cassandra.policies', + 'options': {'local_dc': '', + 'used_hosts_per_remote_dc': 0}, + 'type': 'DCAwareRoundRobinPolicy'}, + 'shuffle_replicas': False}, + 'type': 'TokenAwarePolicy'}}, + 'type': 'DefaultLoadBalancingPolicy'}, + 'readTimeout': 604800.0, + 'retry': {'namespace': 'cassandra.policies', 'options': {}, 'type': 'NeverRetryPolicy'}, + 'serialConsistency': None, + 'speculativeExecution': {'namespace': 'cassandra.policies', + 'options': {}, 'type': 'NoSpeculativeExecutionPolicy'}, + 'graphOptions': {'graphProtocol': None, + 'language': 'gremlin-groovy', + 'source': 'a'}, + } + ) + + # policies.py + def test_DC_aware_round_robin_policy(self): + self.assertEqual( + insights_registry.serialize(DCAwareRoundRobinPolicy()), + {'namespace': 'cassandra.policies', + 'options': {'local_dc': '', 'used_hosts_per_remote_dc': 0}, + 'type': 'DCAwareRoundRobinPolicy'} + ) + self.assertEqual( + insights_registry.serialize(DCAwareRoundRobinPolicy(local_dc='fake_local_dc', + used_hosts_per_remote_dc=15)), + {'namespace': 'cassandra.policies', + 'options': {'local_dc': 'fake_local_dc', 'used_hosts_per_remote_dc': 15}, + 'type': 'DCAwareRoundRobinPolicy'} + ) + + def test_token_aware_policy(self): + self.assertEqual( + insights_registry.serialize(TokenAwarePolicy(child_policy=LoadBalancingPolicy())), + {'namespace': 'cassandra.policies', + 'options': {'child_policy': {'namespace': 'cassandra.policies', + 'options': {}, + 'type': 'LoadBalancingPolicy'}, + 'shuffle_replicas': False}, + 'type': 'TokenAwarePolicy'} + ) + + def test_whitelist_round_robin_policy(self): + self.assertEqual( + insights_registry.serialize(WhiteListRoundRobinPolicy(['127.0.0.3'])), + {'namespace': 'cassandra.policies', + 'options': {'allowed_hosts': ('127.0.0.3',)}, + 'type': 'WhiteListRoundRobinPolicy'} + ) + + def test_host_filter_policy(self): + def my_predicate(s): + return False + + self.assertEqual( + insights_registry.serialize(HostFilterPolicy(LoadBalancingPolicy(), my_predicate)), + {'namespace': 'cassandra.policies', + 'options': {'child_policy': {'namespace': 'cassandra.policies', + 'options': {}, + 'type': 'LoadBalancingPolicy'}, + 'predicate': 'my_predicate'}, + 'type': 'HostFilterPolicy'} + ) + + def test_constant_reconnection_policy(self): + self.assertEqual( + insights_registry.serialize(ConstantReconnectionPolicy(3, 200)), + {'type': 'ConstantReconnectionPolicy', + 'namespace': 'cassandra.policies', + 'options': {'delay': 3, 'max_attempts': 200} + } + ) + + def test_exponential_reconnection_policy(self): + self.assertEqual( + insights_registry.serialize(ExponentialReconnectionPolicy(4, 100, 10)), + {'type': 'ExponentialReconnectionPolicy', + 'namespace': 'cassandra.policies', + 'options': {'base_delay': 4, 'max_delay': 100, 'max_attempts': 10} + } + ) + + def test_retry_policy(self): + self.assertEqual( + insights_registry.serialize(RetryPolicy()), + {'type': 'RetryPolicy', + 'namespace': 'cassandra.policies', + 'options': {} + } + ) + + def test_spec_exec_policy(self): + self.assertEqual( + insights_registry.serialize(SpeculativeExecutionPolicy()), + {'type': 'SpeculativeExecutionPolicy', + 'namespace': 'cassandra.policies', + 'options': {} + } + ) + + def test_constant_spec_exec_policy(self): + self.assertEqual( + insights_registry.serialize(ConstantSpeculativeExecutionPolicy(100, 101)), + {'type': 'ConstantSpeculativeExecutionPolicy', + 'namespace': 'cassandra.policies', + 'options': {'delay': 100, + 'max_attempts': 101} + } + ) + + def test_wrapper_policy(self): + 
self.assertEqual( + insights_registry.serialize(WrapperPolicy(LoadBalancingPolicy())), + {'namespace': 'cassandra.policies', + 'options': {'child_policy': {'namespace': 'cassandra.policies', + 'options': {}, + 'type': 'LoadBalancingPolicy'} + }, + 'type': 'WrapperPolicy'} + ) diff --git a/tests/unit/advanced/test_metadata.py b/tests/unit/advanced/test_metadata.py new file mode 100644 index 0000000000..052ad3f465 --- /dev/null +++ b/tests/unit/advanced/test_metadata.py @@ -0,0 +1,140 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from cassandra.metadata import ( + KeyspaceMetadata, TableMetadataDSE68, + VertexMetadata, EdgeMetadata +) + + +class GraphMetadataToCQLTests(unittest.TestCase): + + def _create_edge_metadata(self, partition_keys=['pk1'], clustering_keys=['c1']): + return EdgeMetadata( + 'keyspace', 'table', 'label', 'from_table', 'from_label', + partition_keys, clustering_keys, 'to_table', 'to_label', + partition_keys, clustering_keys) + + def _create_vertex_metadata(self, label_name='label'): + return VertexMetadata('keyspace', 'table', label_name) + + def _create_keyspace_metadata(self, graph_engine): + return KeyspaceMetadata( + 'keyspace', True, 'org.apache.cassandra.locator.SimpleStrategy', + {'replication_factor': 1}, graph_engine=graph_engine) + + def _create_table_metadata(self, with_vertex=False, with_edge=False): + tm = TableMetadataDSE68('keyspace', 'table') + if with_vertex: + tm.vertex = self._create_vertex_metadata() if with_vertex is True else with_vertex + elif with_edge: + tm.edge = self._create_edge_metadata() if with_edge is True else with_edge + + return tm + + def test_keyspace_no_graph_engine(self): + km = self._create_keyspace_metadata(None) + self.assertEqual(km.graph_engine, None) + self.assertNotIn( + "graph_engine", + km.as_cql_query() + ) + + def test_keyspace_with_graph_engine(self): + graph_engine = 'Core' + km = self._create_keyspace_metadata(graph_engine) + self.assertEqual(km.graph_engine, graph_engine) + cql = km.as_cql_query() + self.assertIn( + "graph_engine", + cql + ) + self.assertIn( + "Core", + cql + ) + + def test_table_no_vertex_or_edge(self): + tm = self._create_table_metadata() + self.assertIsNone(tm.vertex) + self.assertIsNone(tm.edge) + cql = tm.as_cql_query() + self.assertNotIn("VERTEX LABEL", cql) + self.assertNotIn("EDGE LABEL", cql) + + def test_table_with_vertex(self): + tm = self._create_table_metadata(with_vertex=True) + self.assertIsInstance(tm.vertex, VertexMetadata) + self.assertIsNone(tm.edge) + cql = tm.as_cql_query() + self.assertIn("VERTEX LABEL", cql) + self.assertNotIn("EDGE LABEL", cql) + + def test_table_with_edge(self): + tm = self._create_table_metadata(with_edge=True) + self.assertIsNone(tm.vertex) + self.assertIsInstance(tm.edge, EdgeMetadata) + cql = 
tm.as_cql_query() + self.assertNotIn("VERTEX LABEL", cql) + self.assertIn("EDGE LABEL", cql) + self.assertIn("FROM from_label", cql) + self.assertIn("TO to_label", cql) + + def test_vertex_with_label(self): + tm = self. _create_table_metadata(with_vertex=True) + self.assertTrue(tm.as_cql_query().endswith('VERTEX LABEL label')) + + def test_edge_single_partition_key_and_clustering_key(self): + tm = self._create_table_metadata(with_edge=True) + self.assertIn( + 'FROM from_label(pk1, c1)', + tm.as_cql_query() + ) + + def test_edge_multiple_partition_keys(self): + edge = self._create_edge_metadata(partition_keys=['pk1', 'pk2']) + tm = self. _create_table_metadata(with_edge=edge) + self.assertIn( + 'FROM from_label((pk1, pk2), ', + tm.as_cql_query() + ) + + def test_edge_no_clustering_keys(self): + edge = self._create_edge_metadata(clustering_keys=[]) + tm = self. _create_table_metadata(with_edge=edge) + self.assertIn( + 'FROM from_label(pk1) ', + tm.as_cql_query() + ) + + def test_edge_multiple_clustering_keys(self): + edge = self._create_edge_metadata(clustering_keys=['c1', 'c2']) + tm = self. _create_table_metadata(with_edge=edge) + self.assertIn( + 'FROM from_label(pk1, c1, c2) ', + tm.as_cql_query() + ) + + def test_edge_multiple_partition_and_clustering_keys(self): + edge = self._create_edge_metadata(partition_keys=['pk1', 'pk2'], + clustering_keys=['c1', 'c2']) + tm = self. _create_table_metadata(with_edge=edge) + self.assertIn( + 'FROM from_label((pk1, pk2), c1, c2) ', + tm.as_cql_query() + ) diff --git a/tests/unit/advanced/test_policies.py b/tests/unit/advanced/test_policies.py new file mode 100644 index 0000000000..406263f42b --- /dev/null +++ b/tests/unit/advanced/test_policies.py @@ -0,0 +1,101 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import unittest +from unittest.mock import Mock + +from cassandra.pool import Host +from cassandra.policies import RoundRobinPolicy + +from cassandra.policies import DSELoadBalancingPolicy + + +class ClusterMetaMock(object): + def __init__(self, hosts=None): + self.hosts = hosts or {} + + def get_host(self, addr): + return self.hosts.get(addr) + + +class DSELoadBalancingPolicyTest(unittest.TestCase): + + def test_no_target(self): + node_count = 4 + hosts = list(range(node_count)) + policy = DSELoadBalancingPolicy(RoundRobinPolicy()) + policy.populate(Mock(metadata=ClusterMetaMock()), hosts) + for _ in range(node_count): + query_plan = list(policy.make_query_plan(None, Mock(target_host=None))) + self.assertEqual(sorted(query_plan), hosts) + + def test_status_updates(self): + node_count = 4 + hosts = list(range(node_count)) + policy = DSELoadBalancingPolicy(RoundRobinPolicy()) + policy.populate(Mock(metadata=ClusterMetaMock()), hosts) + policy.on_down(0) + policy.on_remove(1) + policy.on_up(4) + policy.on_add(5) + query_plan = list(policy.make_query_plan()) + self.assertEqual(sorted(query_plan), [2, 3, 4, 5]) + + def test_no_live_nodes(self): + hosts = [0, 1, 2, 3] + policy = RoundRobinPolicy() + policy.populate(None, hosts) + + for i in range(4): + policy.on_down(i) + + query_plan = list(policy.make_query_plan()) + self.assertEqual(query_plan, []) + + def test_target_no_host(self): + node_count = 4 + hosts = list(range(node_count)) + policy = DSELoadBalancingPolicy(RoundRobinPolicy()) + policy.populate(Mock(metadata=ClusterMetaMock()), hosts) + query_plan = list(policy.make_query_plan(None, Mock(target_host='127.0.0.1'))) + self.assertEqual(sorted(query_plan), hosts) + + def test_target_host_down(self): + node_count = 4 + hosts = [Host(i, Mock()) for i in range(node_count)] + target_host = hosts[1] + + policy = DSELoadBalancingPolicy(RoundRobinPolicy()) + policy.populate(Mock(metadata=ClusterMetaMock({'127.0.0.1': target_host})), hosts) + query_plan = list(policy.make_query_plan(None, Mock(target_host='127.0.0.1'))) + self.assertEqual(sorted(query_plan), hosts) + + target_host.is_up = False + policy.on_down(target_host) + query_plan = list(policy.make_query_plan(None, Mock(target_host='127.0.0.1'))) + self.assertNotIn(target_host, query_plan) + + def test_target_host_nominal(self): + node_count = 4 + hosts = [Host(i, Mock()) for i in range(node_count)] + target_host = hosts[1] + target_host.is_up = True + + policy = DSELoadBalancingPolicy(RoundRobinPolicy()) + policy.populate(Mock(metadata=ClusterMetaMock({'127.0.0.1': target_host})), hosts) + for _ in range(10): + query_plan = list(policy.make_query_plan(None, Mock(target_host='127.0.0.1'))) + self.assertEqual(sorted(query_plan), hosts) + self.assertEqual(query_plan[0], target_host) diff --git a/tests/unit/column_encryption/test_policies.py b/tests/unit/column_encryption/test_policies.py new file mode 100644 index 0000000000..f78701aa2f --- /dev/null +++ b/tests/unit/column_encryption/test_policies.py @@ -0,0 +1,171 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest + +from cassandra.policies import ColDesc +from cassandra.column_encryption.policies import AES256ColumnEncryptionPolicy, \ + AES256_BLOCK_SIZE_BYTES, AES256_KEY_SIZE_BYTES + +class AES256ColumnEncryptionPolicyTest(unittest.TestCase): + + def _random_block(self): + return os.urandom(AES256_BLOCK_SIZE_BYTES) + + def _random_key(self): + return os.urandom(AES256_KEY_SIZE_BYTES) + + def _test_round_trip(self, bytes): + coldesc = ColDesc('ks1','table1','col1') + policy = AES256ColumnEncryptionPolicy() + policy.add_column(coldesc, self._random_key(), "blob") + encrypted_bytes = policy.encrypt(coldesc, bytes) + self.assertEqual(bytes, policy.decrypt(coldesc, encrypted_bytes)) + + def test_no_padding_necessary(self): + self._test_round_trip(self._random_block()) + + def test_some_padding_required(self): + for byte_size in range(1,AES256_BLOCK_SIZE_BYTES - 1): + bytes = os.urandom(byte_size) + self._test_round_trip(bytes) + for byte_size in range(AES256_BLOCK_SIZE_BYTES + 1,(2 * AES256_BLOCK_SIZE_BYTES) - 1): + bytes = os.urandom(byte_size) + self._test_round_trip(bytes) + + def test_add_column_invalid_key_size_raises(self): + coldesc = ColDesc('ks1','table1','col1') + policy = AES256ColumnEncryptionPolicy() + for key_size in range(1,AES256_KEY_SIZE_BYTES - 1): + with self.assertRaises(ValueError): + policy.add_column(coldesc, os.urandom(key_size), "blob") + for key_size in range(AES256_KEY_SIZE_BYTES + 1,(2 * AES256_KEY_SIZE_BYTES) - 1): + with self.assertRaises(ValueError): + policy.add_column(coldesc, os.urandom(key_size), "blob") + + def test_add_column_invalid_iv_size_raises(self): + def test_iv_size(iv_size): + policy = AES256ColumnEncryptionPolicy(iv = os.urandom(iv_size)) + policy.add_column(coldesc, os.urandom(AES256_KEY_SIZE_BYTES), "blob") + policy.encrypt(coldesc, os.urandom(128)) + + coldesc = ColDesc('ks1','table1','col1') + for iv_size in range(1,AES256_BLOCK_SIZE_BYTES - 1): + with self.assertRaises(ValueError): + test_iv_size(iv_size) + for iv_size in range(AES256_BLOCK_SIZE_BYTES + 1,(2 * AES256_BLOCK_SIZE_BYTES) - 1): + with self.assertRaises(ValueError): + test_iv_size(iv_size) + + # Finally, confirm that the expected IV size has no issue + test_iv_size(AES256_BLOCK_SIZE_BYTES) + + def test_add_column_null_coldesc_raises(self): + with self.assertRaises(ValueError): + policy = AES256ColumnEncryptionPolicy() + policy.add_column(None, self._random_block(), "blob") + + def test_add_column_null_key_raises(self): + with self.assertRaises(ValueError): + policy = AES256ColumnEncryptionPolicy() + coldesc = ColDesc('ks1','table1','col1') + policy.add_column(coldesc, None, "blob") + + def test_add_column_null_type_raises(self): + with self.assertRaises(ValueError): + policy = AES256ColumnEncryptionPolicy() + coldesc = ColDesc('ks1','table1','col1') + policy.add_column(coldesc, self._random_block(), None) + + def test_add_column_unknown_type_raises(self): + with self.assertRaises(ValueError): + policy = AES256ColumnEncryptionPolicy() + coldesc = ColDesc('ks1','table1','col1') + policy.add_column(coldesc, self._random_block(), "foobar") + + def 
test_encode_and_encrypt_null_coldesc_raises(self): + with self.assertRaises(ValueError): + policy = AES256ColumnEncryptionPolicy() + coldesc = ColDesc('ks1','table1','col1') + policy.add_column(coldesc, self._random_key(), "blob") + policy.encode_and_encrypt(None, self._random_block()) + + def test_encode_and_encrypt_null_obj_raises(self): + with self.assertRaises(ValueError): + policy = AES256ColumnEncryptionPolicy() + coldesc = ColDesc('ks1','table1','col1') + policy.add_column(coldesc, self._random_key(), "blob") + policy.encode_and_encrypt(coldesc, None) + + def test_encode_and_encrypt_unknown_coldesc_raises(self): + with self.assertRaises(ValueError): + policy = AES256ColumnEncryptionPolicy() + coldesc = ColDesc('ks1','table1','col1') + policy.add_column(coldesc, self._random_key(), "blob") + policy.encode_and_encrypt(ColDesc('ks2','table2','col2'), self._random_block()) + + def test_contains_column(self): + coldesc = ColDesc('ks1','table1','col1') + policy = AES256ColumnEncryptionPolicy() + policy.add_column(coldesc, self._random_key(), "blob") + self.assertTrue(policy.contains_column(coldesc)) + self.assertFalse(policy.contains_column(ColDesc('ks2','table1','col1'))) + self.assertFalse(policy.contains_column(ColDesc('ks1','table2','col1'))) + self.assertFalse(policy.contains_column(ColDesc('ks1','table1','col2'))) + self.assertFalse(policy.contains_column(ColDesc('ks2','table2','col2'))) + + def test_encrypt_unknown_column(self): + with self.assertRaises(ValueError): + policy = AES256ColumnEncryptionPolicy() + coldesc = ColDesc('ks1','table1','col1') + policy.add_column(coldesc, self._random_key(), "blob") + policy.encrypt(ColDesc('ks2','table2','col2'), self._random_block()) + + def test_decrypt_unknown_column(self): + policy = AES256ColumnEncryptionPolicy() + coldesc = ColDesc('ks1','table1','col1') + policy.add_column(coldesc, self._random_key(), "blob") + encrypted_bytes = policy.encrypt(coldesc, self._random_block()) + with self.assertRaises(ValueError): + policy.decrypt(ColDesc('ks2','table2','col2'), encrypted_bytes) + + def test_cache_info(self): + # Exclude any interference from tests above + AES256ColumnEncryptionPolicy._build_cipher.cache_clear() + + coldesc1 = ColDesc('ks1','table1','col1') + coldesc2 = ColDesc('ks2','table2','col2') + coldesc3 = ColDesc('ks3','table3','col3') + policy = AES256ColumnEncryptionPolicy() + for coldesc in [coldesc1, coldesc2, coldesc3]: + policy.add_column(coldesc, self._random_key(), "blob") + + # First run for this coldesc should be a miss, everything else should be a cache hit + for _ in range(10): + policy.encrypt(coldesc1, self._random_block()) + cache_info = policy.cache_info() + self.assertEqual(cache_info.hits, 9) + self.assertEqual(cache_info.misses, 1) + self.assertEqual(cache_info.maxsize, 128) + + # Important note: we're measuring the size of the cache of ciphers, NOT stored + # keys. We won't have a cipher here until we actually encrypt something + self.assertEqual(cache_info.currsize, 1) + policy.encrypt(coldesc2, self._random_block()) + self.assertEqual(policy.cache_info().currsize, 2) + policy.encrypt(coldesc3, self._random_block()) + self.assertEqual(policy.cache_info().currsize, 3) diff --git a/tests/unit/cqlengine/__init__.py b/tests/unit/cqlengine/__init__.py new file mode 100644 index 0000000000..588a655d98 --- /dev/null +++ b/tests/unit/cqlengine/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/unit/cqlengine/test_columns.py b/tests/unit/cqlengine/test_columns.py new file mode 100644 index 0000000000..4d264df07c --- /dev/null +++ b/tests/unit/cqlengine/test_columns.py @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from cassandra.cqlengine.columns import Column + + +class ColumnTest(unittest.TestCase): + + def test_comparisons(self): + c0 = Column() + c1 = Column() + self.assertEqual(c1.position - c0.position, 1) + + # __ne__ + self.assertNotEqual(c0, c1) + self.assertNotEqual(c0, object()) + + # __eq__ + self.assertEqual(c0, c0) + self.assertFalse(c0 == object()) + + # __lt__ + self.assertLess(c0, c1) + try: + c0 < object() # this raises for Python 3 + except TypeError: + pass + + # __le__ + self.assertLessEqual(c0, c1) + self.assertLessEqual(c0, c0) + try: + c0 <= object() # this raises for Python 3 + except TypeError: + pass + + # __gt__ + self.assertGreater(c1, c0) + try: + c1 > object() # this raises for Python 3 + except TypeError: + pass + + # __ge__ + self.assertGreaterEqual(c1, c0) + self.assertGreaterEqual(c1, c1) + try: + c1 >= object() # this raises for Python 3 + except TypeError: + pass + + def test_hash(self): + c0 = Column() + self.assertEqual(id(c0), c0.__hash__()) + diff --git a/tests/unit/cqlengine/test_connection.py b/tests/unit/cqlengine/test_connection.py new file mode 100644 index 0000000000..dd7586aff0 --- /dev/null +++ b/tests/unit/cqlengine/test_connection.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import Mock + +from cassandra.cluster import _ConfigMode +from cassandra.cqlengine import connection +from cassandra.query import dict_factory + + +class ConnectionTest(unittest.TestCase): + + no_registered_connection_msg = "doesn't exist in the registry" + + def setUp(self): + super(ConnectionTest, self).setUp() + self.assertFalse( + connection._connections, + 'Test precondition not met: connections are registered: {cs}'.format(cs=connection._connections) + ) + + def test_set_session_without_existing_connection(self): + """ + Users can set the default session without having a default connection set. + """ + mock_cluster = Mock( + _config_mode=_ConfigMode.LEGACY, + ) + mock_session = Mock( + row_factory=dict_factory, + encoder=Mock(mapping={}), + cluster=mock_cluster, + ) + connection.set_session(mock_session) + + def test_get_session_fails_without_existing_connection(self): + """ + Users can't get the default session without having a default connection set. + """ + with self.assertRaisesRegex(connection.CQLEngineException, self.no_registered_connection_msg): + connection.get_session(connection=None) + + def test_get_cluster_fails_without_existing_connection(self): + """ + Users can't get the default cluster without having a default connection set. + """ + with self.assertRaisesRegex(connection.CQLEngineException, self.no_registered_connection_msg): + connection.get_cluster(connection=None) diff --git a/tests/unit/cqlengine/test_udt.py b/tests/unit/cqlengine/test_udt.py new file mode 100644 index 0000000000..de87bf3833 --- /dev/null +++ b/tests/unit/cqlengine/test_udt.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from cassandra.cqlengine import columns +from cassandra.cqlengine.models import Model +from cassandra.cqlengine.usertype import UserType + + +class UDTTest(unittest.TestCase): + + def test_initialization_without_existing_connection(self): + """ + Test that users can define models with UDTs without initializing + connections. + + Written to reproduce PYTHON-649. 
+ """ + + class Value(UserType): + t = columns.Text() + + class DummyUDT(Model): + __keyspace__ = 'ks' + primary_key = columns.Integer(primary_key=True) + value = columns.UserDefinedType(Value) diff --git a/tests/unit/cython/__init__.py b/tests/unit/cython/__init__.py index 87fc3685e0..588a655d98 100644 --- a/tests/unit/cython/__init__.py +++ b/tests/unit/cython/__init__.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/tests/unit/cython/bytesio_testhelper.pyx b/tests/unit/cython/bytesio_testhelper.pyx index e86fdb73c2..dcb8c4a4de 100644 --- a/tests/unit/cython/bytesio_testhelper.pyx +++ b/tests/unit/cython/bytesio_testhelper.pyx @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/tests/unit/cython/test_bytesio.py b/tests/unit/cython/test_bytesio.py index 0c2ae7bfa8..08ca284ff3 100644 --- a/tests/unit/cython/test_bytesio.py +++ b/tests/unit/cython/test_bytesio.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -15,10 +17,7 @@ from tests.unit.cython.utils import cyimport, cythontest bytesio_testhelper = cyimport('tests.unit.cython.bytesio_testhelper') -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest class BytesIOTest(unittest.TestCase): diff --git a/tests/unit/cython/test_types.py b/tests/unit/cython/test_types.py index d9f3a746e1..4ae18639f6 100644 --- a/tests/unit/cython/test_types.py +++ b/tests/unit/cython/test_types.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -15,10 +17,7 @@ from tests.unit.cython.utils import cyimport, cythontest types_testhelper = cyimport('tests.unit.cython.types_testhelper') -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest class TypesTest(unittest.TestCase): diff --git a/tests/unit/cython/test_utils.py b/tests/unit/cython/test_utils.py index 209056f645..e43ae343f4 100644 --- a/tests/unit/cython/test_utils.py +++ b/tests/unit/cython/test_utils.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -15,10 +17,7 @@ from tests.unit.cython.utils import cyimport, cythontest utils_testhelper = cyimport('tests.unit.cython.utils_testhelper') -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest class UtilsTest(unittest.TestCase): @@ -26,4 +25,4 @@ class UtilsTest(unittest.TestCase): @cythontest def test_datetime_from_timestamp(self): - utils_testhelper.test_datetime_from_timestamp(self.assertEqual) \ No newline at end of file + utils_testhelper.test_datetime_from_timestamp(self.assertEqual) diff --git a/tests/unit/cython/types_testhelper.pyx b/tests/unit/cython/types_testhelper.pyx index 3cd60c550f..a9252df7ee 100644 --- a/tests/unit/cython/types_testhelper.pyx +++ b/tests/unit/cython/types_testhelper.pyx @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -38,7 +40,7 @@ def test_datetype(assert_equal): cdef BytesIOReader reader cdef Buffer buf - dt = datetime.datetime.utcfromtimestamp(timestamp) + dt = datetime.datetime.fromtimestamp(timestamp, tz=datetime.timezone.utc) bytes = io.BytesIO() write_value(bytes, DateType.serialize(dt, 0)) @@ -52,7 +54,7 @@ def test_datetype(assert_equal): # deserialize # epoc expected = 0 - assert_equal(deserialize(expected), datetime.datetime.utcfromtimestamp(expected)) + assert_equal(deserialize(expected), datetime.datetime.fromtimestamp(expected, tz=datetime.timezone.utc).replace(tzinfo=None)) # beyond 32b expected = 2 ** 33 diff --git a/tests/unit/cython/utils.py b/tests/unit/cython/utils.py index 6dbeb15c06..de348afffa 100644 --- a/tests/unit/cython/utils.py +++ b/tests/unit/cython/utils.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +15,12 @@ # limitations under the License. from cassandra.cython_deps import HAVE_CYTHON, HAVE_NUMPY - try: - import unittest2 as unittest + from tests import VERIFY_CYTHON except ImportError: - import unittest # noqa + VERIFY_CYTHON = False + +import unittest def cyimport(import_path): """ @@ -34,6 +37,6 @@ def cyimport(import_path): # @cythontest # def test_something(self): ... -cythontest = unittest.skipUnless(HAVE_CYTHON, 'Cython is not available') +cythontest = unittest.skipUnless(HAVE_CYTHON or VERIFY_CYTHON, 'Cython is not available') notcython = unittest.skipIf(HAVE_CYTHON, 'Cython not supported') -numpytest = unittest.skipUnless(HAVE_CYTHON and HAVE_NUMPY, 'NumPy is not available') +numpytest = unittest.skipUnless((HAVE_CYTHON and HAVE_NUMPY) or VERIFY_CYTHON, 'NumPy is not available') diff --git a/tests/unit/cython/utils_testhelper.pyx b/tests/unit/cython/utils_testhelper.pyx index 32816d3a31..8a8294d9c7 100644 --- a/tests/unit/cython/utils_testhelper.pyx +++ b/tests/unit/cython/utils_testhelper.pyx @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/tests/unit/io/__init__.py b/tests/unit/io/__init__.py index 87fc3685e0..588a655d98 100644 --- a/tests/unit/io/__init__.py +++ b/tests/unit/io/__init__.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/tests/unit/io/eventlet_utils.py b/tests/unit/io/eventlet_utils.py index e06d3f777f..ef3e633ac7 100644 --- a/tests/unit/io/eventlet_utils.py +++ b/tests/unit/io/eventlet_utils.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc.
+# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -16,13 +18,21 @@ import os import select import socket -import thread -import Queue +try: + import thread + import Queue + import __builtin__ + # For Python 3 compatibility +except ImportError: + import _thread as thread + import queue as Queue + import builtins as __builtin__ + import threading -import __builtin__ import ssl import time - +import eventlet +from importlib import reload def eventlet_un_patch_all(): """ @@ -34,4 +44,7 @@ def eventlet_un_patch_all(): for to_unpatch in modules_to_unpatch: reload(to_unpatch) +def restore_saved_module(module): + reload(module) + del eventlet.patcher.already_patched[module.__name__] diff --git a/tests/unit/io/gevent_utils.py b/tests/unit/io/gevent_utils.py index 12aab8d2f1..b458d13170 100644 --- a/tests/unit/io/gevent_utils.py +++ b/tests/unit/io/gevent_utils.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/tests/unit/io/test_asyncioreactor.py b/tests/unit/io/test_asyncioreactor.py new file mode 100644 index 0000000000..65708d41dc --- /dev/null +++ b/tests/unit/io/test_asyncioreactor.py @@ -0,0 +1,77 @@ +AsyncioConnection, ASYNCIO_AVAILABLE = None, False +try: + from cassandra.io.asyncioreactor import AsyncioConnection + import asynctest + ASYNCIO_AVAILABLE = True +except (ImportError, SyntaxError): + AsyncioConnection = None + ASYNCIO_AVAILABLE = False + +from tests import is_monkey_patched, connection_class +from tests.unit.io.utils import TimerCallback, TimerTestMixin + +from unittest.mock import patch + +import unittest +import time + +skip_me = (is_monkey_patched() or + (not ASYNCIO_AVAILABLE) or + (connection_class is not AsyncioConnection)) + + +@unittest.skipIf(is_monkey_patched(), 'runtime is monkey patched for another reactor') +@unittest.skipIf(connection_class is not AsyncioConnection, + 'not running asyncio tests; current connection_class is {}'.format(connection_class)) +@unittest.skipUnless(ASYNCIO_AVAILABLE, "asyncio is not available for this runtime") +class AsyncioTimerTests(TimerTestMixin, unittest.TestCase): + + @classmethod + def setUpClass(cls): + if skip_me: + return + cls.connection_class = AsyncioConnection + AsyncioConnection.initialize_reactor() + + @classmethod + def tearDownClass(cls): + if skip_me: + return + if ASYNCIO_AVAILABLE and AsyncioConnection._loop: + AsyncioConnection._loop.stop() + + @property + def create_timer(self): + return self.connection.create_timer + + @property + def _timers(self): + raise RuntimeError('no TimerManager for AsyncioConnection') + + def setUp(self): + if skip_me: + return + socket_patcher = patch('socket.socket') + self.addCleanup(socket_patcher.stop) + socket_patcher.start() + + old_selector = AsyncioConnection._loop._selector + AsyncioConnection._loop._selector = asynctest.TestSelector() + + def reset_selector(): + AsyncioConnection._loop._selector = old_selector + + self.addCleanup(reset_selector) + + super(AsyncioTimerTests, self).setUp() + + def test_timer_cancellation(self): + # Various lists for tracking callback stage + timeout = .1 + callback = TimerCallback(timeout) + timer = self.create_timer(timeout, callback.invoke) + timer.cancel() + # Release context allow for timer thread to run. + time.sleep(.2) + # Assert that the cancellation was honored + self.assertFalse(callback.was_invoked()) diff --git a/tests/unit/io/test_asyncorereactor.py b/tests/unit/io/test_asyncorereactor.py index ab5bd64091..b37df83bf6 100644 --- a/tests/unit/io/test_asyncorereactor.py +++ b/tests/unit/io/test_asyncorereactor.py @@ -1,336 +1,94 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import sys -import six +import platform +import socket +import unittest + +from unittest.mock import patch +from packaging.version import Version +from cassandra import DependencyException try: - import unittest2 as unittest -except ImportError: - import unittest # noqa + import cassandra.io.asyncorereactor as asyncorereactor + from cassandra.io.asyncorereactor import AsyncoreConnection +except DependencyException: + AsyncoreConnection = None -import errno -import math -import time -from mock import patch, Mock -import os -from six import BytesIO -import socket -from socket import error as socket_error -from cassandra.connection import (HEADER_DIRECTION_TO_CLIENT, - ConnectionException, ProtocolError,Timer) -from cassandra.io.asyncorereactor import AsyncoreConnection -from cassandra.protocol import (write_stringmultimap, write_int, write_string, - SupportedMessage, ReadyMessage, ServerError) -from cassandra.marshal import uint8_pack, uint32_pack, int32_pack from tests import is_monkey_patched -from tests.unit.io.utils import submit_and_wait_for_completion, TimerCallback +from tests.unit.io.utils import ReactorTestMixin, TimerTestMixin, noop_if_monkey_patched -class AsyncoreConnectionTest(unittest.TestCase): +class AsyncorePatcher(unittest.TestCase): @classmethod + @noop_if_monkey_patched def setUpClass(cls): if is_monkey_patched(): return AsyncoreConnection.initialize_reactor() - cls.socket_patcher = patch('socket.socket', spec=socket.socket) - cls.mock_socket = cls.socket_patcher.start() - cls.mock_socket().connect_ex.return_value = 0 - cls.mock_socket().getsockopt.return_value = 0 - cls.mock_socket().fileno.return_value = 100 - - AsyncoreConnection.add_channel = lambda *args, **kwargs: None - - @classmethod - def tearDownClass(cls): - if is_monkey_patched(): - return - cls.socket_patcher.stop() - - def setUp(self): - if is_monkey_patched(): - raise unittest.SkipTest("Can't test asyncore with monkey patching") - - def make_connection(self): - c = AsyncoreConnection('1.2.3.4', cql_version='3.0.1') - c.socket = Mock() - c.socket.send.side_effect = lambda x: len(x) - return c - - def make_header_prefix(self, message_class, version=2, stream_id=0): - return six.binary_type().join(map(uint8_pack, [ - 0xff & (HEADER_DIRECTION_TO_CLIENT | version), - 0, # flags (compression) - stream_id, - message_class.opcode # opcode - ])) - - def make_options_body(self): - options_buf = BytesIO() - write_stringmultimap(options_buf, { - 'CQL_VERSION': ['3.0.1'], - 'COMPRESSION': [] - }) - return options_buf.getvalue() - - def make_error_body(self, code, msg): - buf = BytesIO() - write_int(buf, code) - write_string(buf, msg) - return buf.getvalue() - - def make_msg(self, header, body=six.binary_type()): - return header + uint32_pack(len(body)) + body - - def test_successful_connection(self, *args): - c = self.make_connection() - - # let it write the OptionsMessage - c.handle_write() - - # read in a SupportedMessage response - header = self.make_header_prefix(SupportedMessage) - options = self.make_options_body() - c.socket.recv.return_value = 
self.make_msg(header, options) - c.handle_read() - - # let it write out a StartupMessage - c.handle_write() - - header = self.make_header_prefix(ReadyMessage, stream_id=1) - c.socket.recv.return_value = self.make_msg(header) - c.handle_read() - - self.assertTrue(c.connected_event.is_set()) - return c - - def test_egain_on_buffer_size(self, *args): - # get a connection that's already fully started - c = self.test_successful_connection() - - header = six.b('\x00\x00\x00\x00') + int32_pack(20000) - responses = [ - header + (six.b('a') * (4096 - len(header))), - six.b('a') * 4096, - socket_error(errno.EAGAIN), - six.b('a') * 100, - socket_error(errno.EAGAIN)] - - def side_effect(*args): - response = responses.pop(0) - if isinstance(response, socket_error): - raise response - else: - return response - - c.socket.recv.side_effect = side_effect - c.handle_read() - self.assertEqual(c._current_frame.end_pos, 20000 + len(header)) - # the EAGAIN prevents it from reading the last 100 bytes - c._iobuf.seek(0, os.SEEK_END) - pos = c._iobuf.tell() - self.assertEqual(pos, 4096 + 4096) - - # now tell it to read the last 100 bytes - c.handle_read() - c._iobuf.seek(0, os.SEEK_END) - pos = c._iobuf.tell() - self.assertEqual(pos, 4096 + 4096 + 100) - - def test_protocol_error(self, *args): - c = self.make_connection() - - # let it write the OptionsMessage - c.handle_write() - - # read in a SupportedMessage response - header = self.make_header_prefix(SupportedMessage, version=0xa4) - options = self.make_options_body() - c.socket.recv.return_value = self.make_msg(header, options) - c.handle_read() - - # make sure it errored correctly - self.assertTrue(c.is_defunct) - self.assertTrue(c.connected_event.is_set()) - self.assertIsInstance(c.last_error, ProtocolError) - - def test_error_message_on_startup(self, *args): - c = self.make_connection() - # let it write the OptionsMessage - c.handle_write() + socket_patcher = patch('socket.socket', spec=socket.socket) + channel_patcher = patch( + 'cassandra.io.asyncorereactor.AsyncoreConnection.add_channel', + new=(lambda *args, **kwargs: None) + ) - # read in a SupportedMessage response - header = self.make_header_prefix(SupportedMessage) - options = self.make_options_body() - c.socket.recv.return_value = self.make_msg(header, options) - c.handle_read() + cls.mock_socket = socket_patcher.start() + cls.mock_socket.connect_ex.return_value = 0 + cls.mock_socket.getsockopt.return_value = 0 + cls.mock_socket.fileno.return_value = 100 - # let it write out a StartupMessage - c.handle_write() + channel_patcher.start() - header = self.make_header_prefix(ServerError, stream_id=1) - body = self.make_error_body(ServerError.error_code, ServerError.summary) - c.socket.recv.return_value = self.make_msg(header, body) - c.handle_read() + cls.patchers = (socket_patcher, channel_patcher) - # make sure it errored correctly - self.assertTrue(c.is_defunct) - self.assertIsInstance(c.last_error, ConnectionException) - self.assertTrue(c.connected_event.is_set()) - - def test_socket_error_on_write(self, *args): - c = self.make_connection() - - # make the OptionsMessage write fail - c.socket.send.side_effect = socket_error(errno.EIO, "bad stuff!") - c.handle_write() - - # make sure it errored correctly - self.assertTrue(c.is_defunct) - self.assertIsInstance(c.last_error, socket_error) - self.assertTrue(c.connected_event.is_set()) - - def test_blocking_on_write(self, *args): - c = self.make_connection() - - # make the OptionsMessage write block - c.socket.send.side_effect = socket_error(errno.EAGAIN, 
"socket busy") - c.handle_write() - - self.assertFalse(c.is_defunct) - - # try again with normal behavior - c.socket.send.side_effect = lambda x: len(x) - c.handle_write() - self.assertFalse(c.is_defunct) - self.assertTrue(c.socket.send.call_args is not None) - - def test_partial_send(self, *args): - c = self.make_connection() - - # only write the first four bytes of the OptionsMessage - write_size = 4 - c.socket.send.side_effect = None - c.socket.send.return_value = write_size - c.handle_write() - - msg_size = 9 # v3+ frame header - expected_writes = int(math.ceil(float(msg_size) / write_size)) - size_mod = msg_size % write_size - last_write_size = size_mod if size_mod else write_size - self.assertFalse(c.is_defunct) - self.assertEqual(expected_writes, c.socket.send.call_count) - self.assertEqual(last_write_size, len(c.socket.send.call_args[0][0])) - - def test_socket_error_on_read(self, *args): - c = self.make_connection() - - # let it write the OptionsMessage - c.handle_write() - - # read in a SupportedMessage response - c.socket.recv.side_effect = socket_error(errno.EIO, "busy socket") - c.handle_read() - - # make sure it errored correctly - self.assertTrue(c.is_defunct) - self.assertIsInstance(c.last_error, socket_error) - self.assertTrue(c.connected_event.is_set()) - - def test_partial_header_read(self, *args): - c = self.make_connection() - - header = self.make_header_prefix(SupportedMessage) - options = self.make_options_body() - message = self.make_msg(header, options) - - c.socket.recv.return_value = message[0:1] - c.handle_read() - self.assertEqual(c._iobuf.getvalue(), message[0:1]) - - c.socket.recv.return_value = message[1:] - c.handle_read() - self.assertEqual(six.binary_type(), c._iobuf.getvalue()) - - # let it write out a StartupMessage - c.handle_write() - - header = self.make_header_prefix(ReadyMessage, stream_id=1) - c.socket.recv.return_value = self.make_msg(header) - c.handle_read() - - self.assertTrue(c.connected_event.is_set()) - self.assertFalse(c.is_defunct) - - def test_partial_message_read(self, *args): - c = self.make_connection() - - header = self.make_header_prefix(SupportedMessage) - options = self.make_options_body() - message = self.make_msg(header, options) - - # read in the first nine bytes - c.socket.recv.return_value = message[:9] - c.handle_read() - self.assertEqual(c._iobuf.getvalue(), message[:9]) - - # ... 
then read in the rest - c.socket.recv.return_value = message[9:] - c.handle_read() - self.assertEqual(six.binary_type(), c._iobuf.getvalue()) - - # let it write out a StartupMessage - c.handle_write() + @classmethod + @noop_if_monkey_patched + def tearDownClass(cls): + for p in cls.patchers: + try: + p.stop() + except: + pass - header = self.make_header_prefix(ReadyMessage, stream_id=1) - c.socket.recv.return_value = self.make_msg(header) - c.handle_read() +has_asyncore = Version(platform.python_version()) < Version("3.12.0") +@unittest.skipUnless(has_asyncore, "asyncore has been removed in Python 3.12") +class AsyncoreConnectionTest(ReactorTestMixin, AsyncorePatcher): - self.assertTrue(c.connected_event.is_set()) - self.assertFalse(c.is_defunct) + connection_class = AsyncoreConnection + socket_attr_name = 'socket' - def test_multi_timer_validation(self, *args): - """ - Verify that timer timeouts are honored appropriately - """ - c = self.make_connection() - # Tests timers submitted in order at various timeouts - submit_and_wait_for_completion(self, AsyncoreConnection, 0, 100, 1, 100) - # Tests timers submitted in reverse order at various timeouts - submit_and_wait_for_completion(self, AsyncoreConnection, 100, 0, -1, 100) - # Tests timers submitted in varying order at various timeouts - submit_and_wait_for_completion(self, AsyncoreConnection, 0, 100, 1, 100, True) + def setUp(self): + if is_monkey_patched(): + raise unittest.SkipTest("Can't test asyncore with monkey patching") - def test_timer_cancellation(self): - """ - Verify that timer cancellation is honored - """ - # Various lists for tracking callback stage - connection = self.make_connection() - timeout = .1 - callback = TimerCallback(timeout) - timer = connection.create_timer(timeout, callback.invoke) - timer.cancel() - # Release context allow for timer thread to run. - time.sleep(.2) - timer_manager = connection._loop._timers - # Assert that the cancellation was honored - self.assertFalse(timer_manager._queue) - self.assertFalse(timer_manager._new_timers) - self.assertFalse(callback.was_invoked()) +@unittest.skipUnless(has_asyncore, "asyncore has been removed in Python 3.12") +class TestAsyncoreTimer(TimerTestMixin, AsyncorePatcher): + connection_class = AsyncoreConnection + @property + def create_timer(self): + return self.connection.create_timer + @property + def _timers(self): + return asyncorereactor._global_loop._timers + def setUp(self): + if is_monkey_patched(): + raise unittest.SkipTest("Can't test asyncore with monkey patching") + super(TestAsyncoreTimer, self).setUp() diff --git a/tests/unit/io/test_eventletreactor.py b/tests/unit/io/test_eventletreactor.py index 6aa76fa790..8228884a4a 100644 --- a/tests/unit/io/test_eventletreactor.py +++ b/tests/unit/io/test_eventletreactor.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -13,55 +15,64 @@ # limitations under the License. -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest + +from tests.unit.io.utils import TimerTestMixin +from tests import notpypy, EVENT_LOOP_MANAGER -from tests.unit.io.utils import submit_and_wait_for_completion, TimerCallback -from tests import is_eventlet_monkey_patched -import time +from eventlet import monkey_patch +from unittest.mock import patch try: from cassandra.io.eventletreactor import EventletConnection except ImportError: EventletConnection = None # noqa +skip_condition = EventletConnection is None or EVENT_LOOP_MANAGER != "eventlet" +# There are some issues with some versions of pypy and eventlet +@notpypy +@unittest.skipIf(skip_condition, "Skipping the eventlet tests because it's not installed") +class EventletTimerTest(TimerTestMixin, unittest.TestCase): -class EventletTimerTest(unittest.TestCase): + connection_class = EventletConnection + + @classmethod + def setUpClass(cls): + # This is run even though the class is skipped, so we need + # to make sure no monkey patching is happening + if skip_condition: + return + + # This is being added temporarily due to a bug in eventlet: + # https://github.com/eventlet/eventlet/issues/401 + import eventlet + eventlet.sleep() + monkey_patch() + # cls.connection_class = EventletConnection - def setUp(self): - if EventletConnection is None: - raise unittest.SkipTest("Eventlet libraries not available") - if not is_eventlet_monkey_patched(): - raise unittest.SkipTest("Can't test eventlet without monkey patching") EventletConnection.initialize_reactor() + assert EventletConnection._timers is not None + + def setUp(self): + socket_patcher = patch('eventlet.green.socket.socket') + self.addCleanup(socket_patcher.stop) + socket_patcher.start() + + super(EventletTimerTest, self).setUp() + + recv_patcher = patch.object(self.connection._socket, + 'recv', + return_value=b'') + self.addCleanup(recv_patcher.stop) + recv_patcher.start() + + @property + def create_timer(self): + return self.connection.create_timer + + @property + def _timers(self): + return self.connection._timers - def test_multi_timer_validation(self, *args): - """ - Verify that timer timeouts are honored appropriately - """ - # Tests timers submitted in order at various timeouts - submit_and_wait_for_completion(self, EventletConnection, 0, 100, 1, 100) - # Tests timers submitted in reverse order at various timeouts - submit_and_wait_for_completion(self, EventletConnection, 100, 0, -1, 100) - # Tests timers submitted in varying order at various timeouts - submit_and_wait_for_completion(self, EventletConnection, 0, 100, 1, 100, True) - - def test_timer_cancellation(self): - """ - Verify that timer cancellation is honored - """ - - # Various lists for tracking callback stage - timeout = .1 - callback = TimerCallback(timeout) - timer = EventletConnection.create_timer(timeout, callback.invoke) - timer.cancel() - # Release context allow for timer thread to run. 
- time.sleep(.2) - timer_manager = EventletConnection._timers - # Assert that the cancellation was honored - self.assertFalse(timer_manager._queue) - self.assertFalse(timer_manager._new_timers) - self.assertFalse(callback.was_invoked()) + # There is no unpatching because there is not a clear way + # of doing it reliably diff --git a/tests/unit/io/test_geventreactor.py b/tests/unit/io/test_geventreactor.py index 4a8ba30748..9bf0c7895f 100644 --- a/tests/unit/io/test_geventreactor.py +++ b/tests/unit/io/test_geventreactor.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,72 +14,53 @@ # See the License for the specific language governing permissions and # limitations under the License. -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest +from unittest.mock import patch -import time -from tests.unit.io.utils import submit_and_wait_for_completion, TimerCallback -from tests import is_gevent_monkey_patched, is_eventlet_monkey_patched +from tests.unit.io.utils import TimerTestMixin +from tests import EVENT_LOOP_MANAGER try: from cassandra.io.geventreactor import GeventConnection import gevent.monkey - from gevent_utils import gevent_un_patch_all except ImportError: GeventConnection = None # noqa -class GeventTimerTest(unittest.TestCase): +skip_condition = GeventConnection is None or EVENT_LOOP_MANAGER != "gevent" +@unittest.skipIf(skip_condition, "Skipping the gevent tests because it's not installed") +class GeventTimerTest(TimerTestMixin, unittest.TestCase): - need_unpatch = False + connection_class = GeventConnection @classmethod def setUpClass(cls): - if is_eventlet_monkey_patched(): - return # no dynamic patching if we have eventlet applied - if GeventConnection is not None: - if not is_gevent_monkey_patched(): - cls.need_unpatch = True - gevent.monkey.patch_all() - - @classmethod - def tearDownClass(cls): - if cls.need_unpatch: - gevent_un_patch_all() + # This is run even though the class is skipped, so we need + # to make sure no monkey patching is happening + if skip_condition: + return + # There is no unpatching because there is not a clear way + # of doing it reliably + gevent.monkey.patch_all() + GeventConnection.initialize_reactor() def setUp(self): - if not is_gevent_monkey_patched(): - raise unittest.SkipTest("Can't test gevent without monkey patching") - GeventConnection.initialize_reactor() + socket_patcher = patch('gevent.socket.socket') + self.addCleanup(socket_patcher.stop) + socket_patcher.start() - def test_multi_timer_validation(self): - """ - Verify that timer timeouts are honored appropriately - """ + super(GeventTimerTest, self).setUp() - # Tests timers submitted in order at various timeouts - 
submit_and_wait_for_completion(self, GeventConnection, 0, 100, 1, 100) - # Tests timers submitted in reverse order at various timeouts - submit_and_wait_for_completion(self, GeventConnection, 100, 0, -1, 100) - # Tests timers submitted in varying order at various timeouts - submit_and_wait_for_completion(self, GeventConnection, 0, 100, 1, 100, True), + recv_patcher = patch.object(self.connection._socket, + 'recv', + return_value=b'') + self.addCleanup(recv_patcher.stop) + recv_patcher.start() - def test_timer_cancellation(self): - """ - Verify that timer cancellation is honored - """ + @property + def create_timer(self): + return self.connection.create_timer - # Various lists for tracking callback stage - timeout = .1 - callback = TimerCallback(timeout) - timer = GeventConnection.create_timer(timeout, callback.invoke) - timer.cancel() - # Release context allow for timer thread to run. - time.sleep(.2) - timer_manager = GeventConnection._timers - # Assert that the cancellation was honored - self.assertFalse(timer_manager._queue) - self.assertFalse(timer_manager._new_timers) - self.assertFalse(callback.was_invoked()) + @property + def _timers(self): + return self.connection._timers diff --git a/tests/unit/io/test_libevreactor.py b/tests/unit/io/test_libevreactor.py index 309134c940..a4050c79c1 100644 --- a/tests/unit/io/test_libevreactor.py +++ b/tests/unit/io/test_libevreactor.py @@ -1,54 +1,40 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-try: - import unittest2 as unittest -except ImportError: - import unittest # noqa - -import errno -import math -from mock import patch, Mock -import os -import six -from six import BytesIO -from socket import error as socket_error -import sys -import time - -from cassandra.connection import (HEADER_DIRECTION_TO_CLIENT, - ConnectionException, ProtocolError) - -from cassandra.protocol import (write_stringmultimap, write_int, write_string, - SupportedMessage, ReadyMessage, ServerError) -from cassandra.marshal import uint8_pack, uint32_pack, int32_pack -from tests.unit.io.utils import TimerCallback -from tests.unit.io.utils import submit_and_wait_for_completion -from tests import is_monkey_patched +import unittest +from unittest.mock import patch, Mock +import weakref +import socket +from cassandra import DependencyException try: + from cassandra.io.libevreactor import _cleanup as libev__cleanup from cassandra.io.libevreactor import LibevConnection -except ImportError: +except DependencyException: LibevConnection = None # noqa +from tests import is_monkey_patched +from tests.unit.io.utils import ReactorTestMixin, TimerTestMixin, noop_if_monkey_patched + + +class LibevConnectionTest(ReactorTestMixin, unittest.TestCase): -@patch('socket.socket') -@patch('cassandra.io.libevwrapper.IO') -@patch('cassandra.io.libevwrapper.Prepare') -@patch('cassandra.io.libevwrapper.Async') -@patch('cassandra.io.libevreactor.LibevLoop.maybe_start') -class LibevConnectionTest(unittest.TestCase): + connection_class = LibevConnection + socket_attr_name = '_socket' + null_handle_function_args = None, 0 def setUp(self): if is_monkey_patched(): @@ -57,242 +43,101 @@ def setUp(self): raise unittest.SkipTest('libev does not appear to be installed correctly') LibevConnection.initialize_reactor() + # we patch here rather than as a decorator so that the Mixin can avoid + # specifying patch args to test methods + patchers = [patch(obj) for obj in + ('socket.socket', + 'cassandra.io.libevwrapper.IO', + 'cassandra.io.libevreactor.LibevLoop.maybe_start' + )] + for p in patchers: + self.addCleanup(p.stop) + for p in patchers: + p.start() + + def test_watchers_are_finished(self): + """ + Test for asserting that watchers are closed in LibevConnection + + This test simulates a process termination without calling cluster.shutdown(), which would trigger + _global_loop._cleanup. 
It will check the watchers have been closed + Finally it will restore the LibevConnection reactor so it doesn't affect + the rest of the tests + + @since 3.10 + @jira_ticket PYTHON-747 + @expected_result the watchers are closed + + @test_category connection + """ + from cassandra.io.libevreactor import _global_loop + with patch.object(_global_loop, "_thread"),\ + patch.object(_global_loop, "notify"): + + self.make_connection() + + # We have to make a copy because the connections shouldn't + # be alive when we verify them + live_connections = set(_global_loop._live_conns) + + # This simulates the process ending without cluster.shutdown() + # being called, then with atexit _cleanup for libevreactor would + # be called + libev__cleanup(_global_loop) + for conn in live_connections: + self.assertTrue(conn._write_watcher.stop.mock_calls) + self.assertTrue(conn._read_watcher.stop.mock_calls) + + _global_loop._shutdown = False + + +class LibevTimerPatcher(unittest.TestCase): + + @classmethod + @noop_if_monkey_patched + def setUpClass(cls): + if LibevConnection is None: + raise unittest.SkipTest('libev does not appear to be installed correctly') + cls.patchers = [ + patch('socket.socket', spec=socket.socket), + patch('cassandra.io.libevwrapper.IO') + ] + for p in cls.patchers: + p.start() + + @classmethod + @noop_if_monkey_patched + def tearDownClass(cls): + for p in cls.patchers: + try: + p.stop() + except: + pass + + +class LibevTimerTest(TimerTestMixin, LibevTimerPatcher): + connection_class = LibevConnection + + @property + def create_timer(self): + return self.connection.create_timer + + @property + def _timers(self): + from cassandra.io.libevreactor import _global_loop + return _global_loop._timers + def make_connection(self): c = LibevConnection('1.2.3.4', cql_version='3.0.1') - c._socket = Mock() - c._socket.send.side_effect = lambda x: len(x) + c._socket_impl = Mock() + c._socket.return_value.send.side_effect = lambda x: len(x) return c - def make_header_prefix(self, message_class, version=2, stream_id=0): - return six.binary_type().join(map(uint8_pack, [ - 0xff & (HEADER_DIRECTION_TO_CLIENT | version), - 0, # flags (compression) - stream_id, - message_class.opcode # opcode - ])) - - def make_options_body(self): - options_buf = BytesIO() - write_stringmultimap(options_buf, { - 'CQL_VERSION': ['3.0.1'], - 'COMPRESSION': [] - }) - return options_buf.getvalue() - - def make_error_body(self, code, msg): - buf = BytesIO() - write_int(buf, code) - write_string(buf, msg) - return buf.getvalue() - - def make_msg(self, header, body=six.binary_type()): - return header + uint32_pack(len(body)) + body - - def test_successful_connection(self, *args): - c = self.make_connection() - - # let it write the OptionsMessage - c.handle_write(None, 0) - - # read in a SupportedMessage response - header = self.make_header_prefix(SupportedMessage) - options = self.make_options_body() - c._socket.recv.return_value = self.make_msg(header, options) - c.handle_read(None, 0) - - # let it write out a StartupMessage - c.handle_write(None, 0) - - header = self.make_header_prefix(ReadyMessage, stream_id=1) - c._socket.recv.return_value = self.make_msg(header) - c.handle_read(None, 0) - - self.assertTrue(c.connected_event.is_set()) - return c - - def test_egain_on_buffer_size(self, *args): - # get a connection that's already fully started - c = self.test_successful_connection() - - header = six.b('\x00\x00\x00\x00') + int32_pack(20000) - responses = [ - header + (six.b('a') * (4096 - len(header))), - six.b('a') * 4096, - 
socket_error(errno.EAGAIN), - six.b('a') * 100, - socket_error(errno.EAGAIN)] - - def side_effect(*args): - response = responses.pop(0) - if isinstance(response, socket_error): - raise response - else: - return response - - c._socket.recv.side_effect = side_effect - c.handle_read(None, 0) - self.assertEqual(c._current_frame.end_pos, 20000 + len(header)) - # the EAGAIN prevents it from reading the last 100 bytes - c._iobuf.seek(0, os.SEEK_END) - pos = c._iobuf.tell() - self.assertEqual(pos, 4096 + 4096) - - # now tell it to read the last 100 bytes - c.handle_read(None, 0) - c._iobuf.seek(0, os.SEEK_END) - pos = c._iobuf.tell() - self.assertEqual(pos, 4096 + 4096 + 100) - - def test_protocol_error(self, *args): - c = self.make_connection() - - # let it write the OptionsMessage - c.handle_write(None, 0) - - # read in a SupportedMessage response - header = self.make_header_prefix(SupportedMessage, version=0xa4) - options = self.make_options_body() - c._socket.recv.return_value = self.make_msg(header, options) - c.handle_read(None, 0) - - # make sure it errored correctly - self.assertTrue(c.is_defunct) - self.assertTrue(c.connected_event.is_set()) - self.assertIsInstance(c.last_error, ProtocolError) - - def test_error_message_on_startup(self, *args): - c = self.make_connection() - - # let it write the OptionsMessage - c.handle_write(None, 0) - - # read in a SupportedMessage response - header = self.make_header_prefix(SupportedMessage) - options = self.make_options_body() - c._socket.recv.return_value = self.make_msg(header, options) - c.handle_read(None, 0) - - # let it write out a StartupMessage - c.handle_write(None, 0) - - header = self.make_header_prefix(ServerError, stream_id=1) - body = self.make_error_body(ServerError.error_code, ServerError.summary) - c._socket.recv.return_value = self.make_msg(header, body) - c.handle_read(None, 0) - - # make sure it errored correctly - self.assertTrue(c.is_defunct) - self.assertIsInstance(c.last_error, ConnectionException) - self.assertTrue(c.connected_event.is_set()) - - def test_socket_error_on_write(self, *args): - c = self.make_connection() - - # make the OptionsMessage write fail - c._socket.send.side_effect = socket_error(errno.EIO, "bad stuff!") - c.handle_write(None, 0) - - # make sure it errored correctly - self.assertTrue(c.is_defunct) - self.assertIsInstance(c.last_error, socket_error) - self.assertTrue(c.connected_event.is_set()) - - def test_blocking_on_write(self, *args): - c = self.make_connection() - - # make the OptionsMessage write block - c._socket.send.side_effect = socket_error(errno.EAGAIN, "socket busy") - c.handle_write(None, 0) - - self.assertFalse(c.is_defunct) - - # try again with normal behavior - c._socket.send.side_effect = lambda x: len(x) - c.handle_write(None, 0) - self.assertFalse(c.is_defunct) - self.assertTrue(c._socket.send.call_args is not None) - - def test_partial_send(self, *args): - c = self.make_connection() - - # only write the first four bytes of the OptionsMessage - write_size = 4 - c._socket.send.side_effect = None - c._socket.send.return_value = write_size - c.handle_write(None, 0) - - msg_size = 9 # v3+ frame header - expected_writes = int(math.ceil(float(msg_size) / write_size)) - size_mod = msg_size % write_size - last_write_size = size_mod if size_mod else write_size - self.assertFalse(c.is_defunct) - self.assertEqual(expected_writes, c._socket.send.call_count) - self.assertEqual(last_write_size, len(c._socket.send.call_args[0][0])) - - def test_socket_error_on_read(self, *args): - c = 
self.make_connection() - - # let it write the OptionsMessage - c.handle_write(None, 0) - - # read in a SupportedMessage response - c._socket.recv.side_effect = socket_error(errno.EIO, "busy socket") - c.handle_read(None, 0) - - # make sure it errored correctly - self.assertTrue(c.is_defunct) - self.assertIsInstance(c.last_error, socket_error) - self.assertTrue(c.connected_event.is_set()) - - def test_partial_header_read(self, *args): - c = self.make_connection() - - header = self.make_header_prefix(SupportedMessage) - options = self.make_options_body() - message = self.make_msg(header, options) - - # read in the first byte - c._socket.recv.return_value = message[0:1] - c.handle_read(None, 0) - self.assertEqual(c._iobuf.getvalue(), message[0:1]) - - c._socket.recv.return_value = message[1:] - c.handle_read(None, 0) - self.assertEqual(six.binary_type(), c._iobuf.getvalue()) - - # let it write out a StartupMessage - c.handle_write(None, 0) - - header = self.make_header_prefix(ReadyMessage, stream_id=1) - c._socket.recv.return_value = self.make_msg(header) - c.handle_read(None, 0) - - self.assertTrue(c.connected_event.is_set()) - self.assertFalse(c.is_defunct) - - def test_partial_message_read(self, *args): - c = self.make_connection() - - header = self.make_header_prefix(SupportedMessage) - options = self.make_options_body() - message = self.make_msg(header, options) - - # read in the first nine bytes - c._socket.recv.return_value = message[:9] - c.handle_read(None, 0) - self.assertEqual(c._iobuf.getvalue(), message[:9]) - - # ... then read in the rest - c._socket.recv.return_value = message[9:] - c.handle_read(None, 0) - self.assertEqual(six.binary_type(), c._iobuf.getvalue()) - - # let it write out a StartupMessage - c.handle_write(None, 0) - - header = self.make_header_prefix(ReadyMessage, stream_id=1) - c._socket.recv.return_value = self.make_msg(header) - c.handle_read(None, 0) + def setUp(self): + if is_monkey_patched(): + raise unittest.SkipTest("Can't test libev with monkey patching.") + if LibevConnection is None: + raise unittest.SkipTest('libev does not appear to be installed correctly') - self.assertTrue(c.connected_event.is_set()) - self.assertFalse(c.is_defunct) + LibevConnection.initialize_reactor() + super(LibevTimerTest, self).setUp() diff --git a/tests/unit/io/test_libevtimer.py b/tests/unit/io/test_libevtimer.py deleted file mode 100644 index 4b21d03170..0000000000 --- a/tests/unit/io/test_libevtimer.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright 2013-2016 DataStax, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-try: - import unittest2 as unittest -except ImportError: - import unittest # noqa - - -from mock import patch, Mock - -import time - -from tests.unit.io.utils import submit_and_wait_for_completion, TimerCallback -from tests import is_monkey_patched - - -try: - from cassandra.io.libevreactor import LibevConnection -except ImportError: - LibevConnection = None # noqa - - -@patch('socket.socket') -@patch('cassandra.io.libevwrapper.IO') -class LibevTimerTest(unittest.TestCase): - - def setUp(self): - if is_monkey_patched(): - raise unittest.SkipTest("Can't test libev with monkey patching") - if LibevConnection is None: - raise unittest.SkipTest('libev does not appear to be installed correctly') - LibevConnection.initialize_reactor() - - def make_connection(self): - c = LibevConnection('1.2.3.4', cql_version='3.0.1') - c._socket = Mock() - c._socket.send.side_effect = lambda x: len(x) - return c - - def test_multi_timer_validation(self, *args): - """ - Verify that timer timeouts are honored appropriately - """ - c = self.make_connection() - c.initialize_reactor() - # Tests timers submitted in order at various timeouts - submit_and_wait_for_completion(self, c, 0, 100, 1, 100) - # Tests timers submitted in reverse order at various timeouts - submit_and_wait_for_completion(self, c, 100, 0, -1, 100) - # Tests timers submitted in varying order at various timeouts - submit_and_wait_for_completion(self, c, 0, 100, 1, 100, True) - - def test_timer_cancellation(self, *args): - """ - Verify that timer cancellation is honored - """ - - # Various lists for tracking callback stage - connection = self.make_connection() - timeout = .1 - callback = TimerCallback(timeout) - timer = connection.create_timer(timeout, callback.invoke) - timer.cancel() - # Release context allow for timer thread to run. - time.sleep(.2) - timer_manager = connection._libevloop._timers - # Assert that the cancellation was honored - self.assertFalse(timer_manager._queue) - self.assertFalse(timer_manager._new_timers) - self.assertFalse(callback.was_invoked()) - diff --git a/tests/unit/io/test_twistedreactor.py b/tests/unit/io/test_twistedreactor.py index fd4181460a..67c4d8eaf3 100644 --- a/tests/unit/io/test_twistedreactor.py +++ b/tests/unit/io/test_twistedreactor.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,68 +14,45 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-try: - import unittest2 as unittest -except ImportError: - import unittest -from mock import Mock, patch -import time +import unittest +from unittest.mock import Mock, patch + +from cassandra.connection import DefaultEndPoint try: from twisted.test import proto_helpers from twisted.python.failure import Failure from cassandra.io import twistedreactor + from cassandra.io.twistedreactor import TwistedConnection except ImportError: - twistedreactor = None # NOQA + twistedreactor = TwistedConnection = None # NOQA + from cassandra.connection import _Frame -from tests.unit.io.utils import submit_and_wait_for_completion, TimerCallback +from tests.unit.io.utils import TimerTestMixin -class TestTwistedTimer(unittest.TestCase): +class TestTwistedTimer(TimerTestMixin, unittest.TestCase): """ Simple test class that is used to validate that the TimerManager, and timer classes function appropriately with the twisted infrastructure """ + connection_class = TwistedConnection + + @property + def create_timer(self): + return self.connection.create_timer + + @property + def _timers(self): + return self.connection._loop._timers + def setUp(self): if twistedreactor is None: raise unittest.SkipTest("Twisted libraries not available") twistedreactor.TwistedConnection.initialize_reactor() - - def test_multi_timer_validation(self): - """ - Verify that the timers are called in the correct order - """ - twistedreactor.TwistedConnection.initialize_reactor() - connection = twistedreactor.TwistedConnection('1.2.3.4', - cql_version='3.0.1') - # Tests timers submitted in order at various timeouts - submit_and_wait_for_completion(self, connection, 0, 100, 1, 100) - # Tests timers submitted in reverse order at various timeouts - submit_and_wait_for_completion(self, connection, 100, 0, -1, 100) - # Tests timers submitted in varying order at various timeouts - submit_and_wait_for_completion(self, connection, 0, 100, 1, 100, True) - - def test_timer_cancellation(self, *args): - """ - Verify that timer cancellation is honored - """ - - # Various lists for tracking callback stage - connection = twistedreactor.TwistedConnection('1.2.3.4', - cql_version='3.0.1') - timeout = .1 - callback = TimerCallback(timeout) - timer = connection.create_timer(timeout, callback.invoke) - timer.cancel() - # Release context allow for timer thread to run. 
- time.sleep(.2) - timer_manager = connection._loop._timers - # Assert that the cancellation was honored - self.assertFalse(timer_manager._queue) - self.assertFalse(timer_manager._new_timers) - self.assertFalse(callback.was_invoked()) + super(TestTwistedTimer, self).setUp() class TestTwistedProtocol(unittest.TestCase): @@ -85,13 +64,13 @@ def setUp(self): self.tr = proto_helpers.StringTransportWithDisconnection() self.tr.connector = Mock() self.mock_connection = Mock() - self.tr.connector.factory = twistedreactor.TwistedConnectionClientFactory( - self.mock_connection) - self.obj_ut = twistedreactor.TwistedConnectionProtocol() + self.obj_ut = twistedreactor.TwistedConnectionProtocol(self.mock_connection) self.tr.protocol = self.obj_ut def tearDown(self): - pass + loop = twistedreactor.TwistedConnection._loop + if not loop._reactor_stopped(): + loop._cleanup() def test_makeConnection(self): """ @@ -112,32 +91,6 @@ def test_receiving_data(self): self.mock_connection._iobuf.write.assert_called_with("foobar") -class TestTwistedClientFactory(unittest.TestCase): - def setUp(self): - if twistedreactor is None: - raise unittest.SkipTest("Twisted libraries not available") - twistedreactor.TwistedConnection.initialize_reactor() - self.mock_connection = Mock() - self.obj_ut = twistedreactor.TwistedConnectionClientFactory( - self.mock_connection) - - def test_client_connection_failed(self): - """ - Verify that connection failed causes the connection object to close. - """ - exc = Exception('a test') - self.obj_ut.clientConnectionFailed(None, Failure(exc)) - self.mock_connection.defunct.assert_called_with(exc) - - def test_client_connection_lost(self): - """ - Verify that connection lost causes the connection object to close. - """ - exc = Exception('a test') - self.obj_ut.clientConnectionLost(None, Failure(exc)) - self.mock_connection.defunct.assert_called_with(exc) - - class TestTwistedConnection(unittest.TestCase): def setUp(self): if twistedreactor is None: @@ -148,7 +101,7 @@ def setUp(self): self.reactor_run_patcher = patch('twisted.internet.reactor.run') self.mock_reactor_cft = self.reactor_cft_patcher.start() self.mock_reactor_run = self.reactor_run_patcher.start() - self.obj_ut = twistedreactor.TwistedConnection('1.2.3.4', + self.obj_ut = twistedreactor.TwistedConnection(DefaultEndPoint('1.2.3.4'), cql_version='3.0.1') def tearDown(self): @@ -163,22 +116,13 @@ def test_connection_initialization(self): self.obj_ut._loop._cleanup() self.mock_reactor_run.assert_called_with(installSignalHandlers=False) - @patch('twisted.internet.reactor.connectTCP') - def test_add_connection(self, mock_connectTCP): - """ - Verify that add_connection() gives us a valid twisted connector. - """ - self.obj_ut.add_connection() - self.assertTrue(self.obj_ut.connector is not None) - self.assertTrue(mock_connectTCP.called) - def test_client_connection_made(self): """ Verifiy that _send_options_message() is called in client_connection_made() """ self.obj_ut._send_options_message = Mock() - self.obj_ut.client_connection_made() + self.obj_ut.client_connection_made(Mock()) self.obj_ut._send_options_message.assert_called_with() @patch('twisted.internet.reactor.connectTCP') @@ -186,11 +130,13 @@ def test_close(self, mock_connectTCP): """ Verify that close() disconnects the connector and errors callbacks. 
""" + transport = Mock() self.obj_ut.error_all_requests = Mock() self.obj_ut.add_connection() + self.obj_ut.client_connection_made(transport) self.obj_ut.is_closed = False self.obj_ut.close() - self.obj_ut.connector.disconnect.assert_called_with() + self.assertTrue(self.obj_ut.connected_event.is_set()) self.assertTrue(self.obj_ut.error_all_requests.called) @@ -203,12 +149,12 @@ def test_handle_read__incomplete(self): # incomplete header self.obj_ut._iobuf.write(b'\x84\x00\x00\x00\x00') self.obj_ut.handle_read() - self.assertEqual(self.obj_ut._iobuf.getvalue(), b'\x84\x00\x00\x00\x00') + self.assertEqual(self.obj_ut._io_buffer.cql_frame_buffer.getvalue(), b'\x84\x00\x00\x00\x00') # full header, but incomplete body self.obj_ut._iobuf.write(b'\x00\x00\x00\x15') self.obj_ut.handle_read() - self.assertEqual(self.obj_ut._iobuf.getvalue(), + self.assertEqual(self.obj_ut._io_buffer.cql_frame_buffer.getvalue(), b'\x84\x00\x00\x00\x00\x00\x00\x00\x15') self.assertEqual(self.obj_ut._current_frame.end_pos, 30) @@ -229,9 +175,9 @@ def test_handle_read__fullmessage(self): self.obj_ut._iobuf.write( b'\x84\x01\x00\x02\x03\x00\x00\x00\x15' + body + extra) self.obj_ut.handle_read() - self.assertEqual(self.obj_ut._iobuf.getvalue(), extra) + self.assertEqual(self.obj_ut._io_buffer.cql_frame_buffer.getvalue(), extra) self.obj_ut.process_msg.assert_called_with( - _Frame(version=4, flags=1, stream=2, opcode=3, body_offset=9, end_pos= 9 + len(body)), body) + _Frame(version=4, flags=1, stream=2, opcode=3, body_offset=9, end_pos=9 + len(body)), body) @patch('twisted.internet.reactor.connectTCP') def test_push(self, mock_connectTCP): @@ -239,7 +185,8 @@ def test_push(self, mock_connectTCP): Verifiy that push() calls transport.write(data). """ self.obj_ut.add_connection() + transport_mock = Mock() + self.obj_ut.transport = transport_mock self.obj_ut.push('123 pickup') self.mock_reactor_cft.assert_called_with( - self.obj_ut.connector.transport.write, '123 pickup') - + transport_mock.write, '123 pickup') diff --git a/tests/unit/io/utils.py b/tests/unit/io/utils.py index 58ed78ea26..d4483d08c7 100644 --- a/tests/unit/io/utils.py +++ b/tests/unit/io/utils.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,9 +14,38 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from cassandra.connection import ( + ConnectionException, ProtocolError, HEADER_DIRECTION_TO_CLIENT +) +from cassandra.marshal import uint8_pack, uint32_pack +from cassandra.protocol import ( + write_stringmultimap, write_int, write_string, SupportedMessage, ReadyMessage, ServerError +) +from cassandra.connection import DefaultEndPoint +from tests import is_monkey_patched + +import io +import random +from functools import wraps +from itertools import cycle +from io import BytesIO +from unittest.mock import Mock + +import errno +import logging +import math +import os +from socket import error as socket_error +import ssl + +import unittest + import time +log = logging.getLogger(__name__) + + class TimerCallback(object): invoked = False @@ -68,7 +99,7 @@ def get_timeout(gross_time, start, end, precision, split_range): return timeout -def submit_and_wait_for_completion(unit_test, connection, start, end, increment, precision, split_range=False): +def submit_and_wait_for_completion(unit_test, create_timer, start, end, increment, precision, split_range=False): """ This will submit a number of timers to the provided connection. It will then ensure that the corresponding callback is invoked in the appropriate amount of time. @@ -89,7 +120,7 @@ def submit_and_wait_for_completion(unit_test, connection, start, end, increment, for gross_time in range(start, end, increment): timeout = get_timeout(gross_time, start, end, precision, split_range) callback = TimerCallback(timeout) - connection.create_timer(timeout, callback.invoke) + create_timer(timeout, callback.invoke) pending_callbacks.append(callback) # wait for all the callbacks associated with the timers to be invoked @@ -103,3 +134,382 @@ def submit_and_wait_for_completion(unit_test, connection, start, end, increment, # ensure they are all called back in a timely fashion for callback in completed_callbacks: unit_test.assertAlmostEqual(callback.expected_wait, callback.get_wait_time(), delta=.15) + + +def noop_if_monkey_patched(f): + if is_monkey_patched(): + @wraps(f) + def noop(*args, **kwargs): + return + return noop + + return f + + +class TimerTestMixin(object): + + connection_class = connection = None + # replace with property returning the connection's create_timer and _timers + create_timer = _timers = None + + def setUp(self): + self.connection = self.connection_class( + DefaultEndPoint("127.0.0.1"), + connect_timeout=5 + ) + + def tearDown(self): + self.connection.close() + + def test_multi_timer_validation(self): + """ + Verify that timer timeouts are honored appropriately + """ + # Tests timers submitted in order at various timeouts + submit_and_wait_for_completion(self, self.create_timer, 0, 100, 1, 100) + # Tests timers submitted in reverse order at various timeouts + submit_and_wait_for_completion(self, self.create_timer, 100, 0, -1, 100) + # Tests timers submitted in varying order at various timeouts + submit_and_wait_for_completion(self, self.create_timer, 0, 100, 1, 100, True), + + def test_timer_cancellation(self): + """ + Verify that timer cancellation is honored + """ + + # Various lists for tracking callback stage + timeout = .1 + callback = TimerCallback(timeout) + timer = self.create_timer(timeout, callback.invoke) + timer.cancel() + # Release context allow for timer thread to run. 
+ time.sleep(timeout * 2) + timer_manager = self._timers + # Assert that the cancellation was honored + self.assertFalse(timer_manager._queue) + self.assertFalse(timer_manager._new_timers) + self.assertFalse(callback.was_invoked()) + + +class ReactorTestMixin(object): + + connection_class = socket_attr_name = None + null_handle_function_args = () + + def get_socket(self, connection): + return getattr(connection, self.socket_attr_name) + + def set_socket(self, connection, obj): + return setattr(connection, self.socket_attr_name, obj) + + def make_header_prefix(self, message_class, version=2, stream_id=0): + return bytes().join(map(uint8_pack, [ + 0xff & (HEADER_DIRECTION_TO_CLIENT | version), + 0, # flags (compression) + stream_id, + message_class.opcode # opcode + ])) + + def make_connection(self): + c = self.connection_class(DefaultEndPoint('1.2.3.4'), cql_version='3.0.1', connect_timeout=5) + mocket = Mock() + mocket.send.side_effect = lambda x: len(x) + self.set_socket(c, mocket) + return c + + def make_options_body(self): + options_buf = BytesIO() + write_stringmultimap(options_buf, { + 'CQL_VERSION': ['3.0.1'], + 'COMPRESSION': [] + }) + return options_buf.getvalue() + + def make_error_body(self, code, msg): + buf = BytesIO() + write_int(buf, code) + write_string(buf, msg) + return buf.getvalue() + + def make_msg(self, header, body=bytes()): + return header + uint32_pack(len(body)) + body + + def test_successful_connection(self): + c = self.make_connection() + + # let it write the OptionsMessage + c.handle_write(*self.null_handle_function_args) + + # read in a SupportedMessage response + header = self.make_header_prefix(SupportedMessage) + options = self.make_options_body() + self.get_socket(c).recv.return_value = self.make_msg(header, options) + c.handle_read(*self.null_handle_function_args) + + # let it write out a StartupMessage + c.handle_write(*self.null_handle_function_args) + + header = self.make_header_prefix(ReadyMessage, stream_id=1) + self.get_socket(c).recv.return_value = self.make_msg(header) + c.handle_read(*self.null_handle_function_args) + + self.assertTrue(c.connected_event.is_set()) + return c + + def test_eagain_on_buffer_size(self): + self._check_error_recovery_on_buffer_size(errno.EAGAIN) + + def test_ewouldblock_on_buffer_size(self): + self._check_error_recovery_on_buffer_size(errno.EWOULDBLOCK) + + def test_sslwantread_on_buffer_size(self): + self._check_error_recovery_on_buffer_size( + ssl.SSL_ERROR_WANT_READ, + error_class=ssl.SSLError) + + def test_sslwantwrite_on_buffer_size(self): + self._check_error_recovery_on_buffer_size( + ssl.SSL_ERROR_WANT_WRITE, + error_class=ssl.SSLError) + + def _check_error_recovery_on_buffer_size(self, error_code, error_class=socket_error): + c = self.test_successful_connection() + + # current data, used by the recv side_effect + message_chunks = None + + def recv_side_effect(*args): + response = message_chunks.pop(0) + if isinstance(response, error_class): + raise response + else: + return response + + # setup + self.get_socket(c).recv.side_effect = recv_side_effect + c.process_io_buffer = Mock() + + def chunk(size): + return b'a' * size + + buf_size = c.in_buffer_size + + # List of messages to test. 
A message = (chunks, expected_read_size) + messages = [ + ([chunk(200)], 200), + ([chunk(200), chunk(200)], 200), # first chunk < in_buffer_size, process the message + ([chunk(buf_size), error_class(error_code)], buf_size), + ([chunk(buf_size), chunk(buf_size), error_class(error_code)], buf_size*2), + ([chunk(buf_size), chunk(buf_size), chunk(10)], (buf_size*2) + 10), + ([chunk(buf_size), chunk(buf_size), error_class(error_code), chunk(10)], buf_size*2), + ([error_class(error_code), chunk(buf_size)], 0) + ] + + for message, expected_size in messages: + message_chunks = message + c._io_buffer._io_buffer = io.BytesIO() + c.process_io_buffer.reset_mock() + c.handle_read(*self.null_handle_function_args) + c._io_buffer.io_buffer.seek(0, os.SEEK_END) + + # Ensure the message size is the good one and that the + # message has been processed if it is non-empty + self.assertEqual(c._io_buffer.io_buffer.tell(), expected_size) + if expected_size == 0: + c.process_io_buffer.assert_not_called() + else: + c.process_io_buffer.assert_called_once_with() + + def test_protocol_error(self): + c = self.make_connection() + + # let it write the OptionsMessage + c.handle_write(*self.null_handle_function_args) + + # read in a SupportedMessage response + header = self.make_header_prefix(SupportedMessage, version=0xa4) + options = self.make_options_body() + self.get_socket(c).recv.return_value = self.make_msg(header, options) + c.handle_read(*self.null_handle_function_args) + + # make sure it errored correctly + self.assertTrue(c.is_defunct) + self.assertTrue(c.connected_event.is_set()) + self.assertIsInstance(c.last_error, ProtocolError) + + def test_error_message_on_startup(self): + c = self.make_connection() + + # let it write the OptionsMessage + c.handle_write(*self.null_handle_function_args) + + # read in a SupportedMessage response + header = self.make_header_prefix(SupportedMessage) + options = self.make_options_body() + self.get_socket(c).recv.return_value = self.make_msg(header, options) + c.handle_read(*self.null_handle_function_args) + + # let it write out a StartupMessage + c.handle_write(*self.null_handle_function_args) + + header = self.make_header_prefix(ServerError, stream_id=1) + body = self.make_error_body(ServerError.error_code, ServerError.summary) + self.get_socket(c).recv.return_value = self.make_msg(header, body) + c.handle_read(*self.null_handle_function_args) + + # make sure it errored correctly + self.assertTrue(c.is_defunct) + self.assertIsInstance(c.last_error, ConnectionException) + self.assertTrue(c.connected_event.is_set()) + + def test_socket_error_on_write(self): + c = self.make_connection() + + # make the OptionsMessage write fail + self.get_socket(c).send.side_effect = socket_error(errno.EIO, "bad stuff!") + c.handle_write(*self.null_handle_function_args) + + # make sure it errored correctly + self.assertTrue(c.is_defunct) + self.assertIsInstance(c.last_error, socket_error) + self.assertTrue(c.connected_event.is_set()) + + def test_blocking_on_write(self): + c = self.make_connection() + + # make the OptionsMessage write block + self.get_socket(c).send.side_effect = socket_error(errno.EAGAIN, + "socket busy") + c.handle_write(*self.null_handle_function_args) + + self.assertFalse(c.is_defunct) + + # try again with normal behavior + self.get_socket(c).send.side_effect = lambda x: len(x) + c.handle_write(*self.null_handle_function_args) + self.assertFalse(c.is_defunct) + self.assertTrue(self.get_socket(c).send.call_args is not None) + + def test_partial_send(self): + c = 
self.make_connection() + + # only write the first four bytes of the OptionsMessage + write_size = 4 + self.get_socket(c).send.side_effect = None + self.get_socket(c).send.return_value = write_size + c.handle_write(*self.null_handle_function_args) + + msg_size = 9 # v3+ frame header + expected_writes = int(math.ceil(float(msg_size) / write_size)) + size_mod = msg_size % write_size + last_write_size = size_mod if size_mod else write_size + self.assertFalse(c.is_defunct) + self.assertEqual(expected_writes, self.get_socket(c).send.call_count) + self.assertEqual(last_write_size, + len(self.get_socket(c).send.call_args[0][0])) + + def test_socket_error_on_read(self): + c = self.make_connection() + + # let it write the OptionsMessage + c.handle_write(*self.null_handle_function_args) + + # read in a SupportedMessage response + self.get_socket(c).recv.side_effect = socket_error(errno.EIO, + "busy socket") + c.handle_read(*self.null_handle_function_args) + + # make sure it errored correctly + self.assertTrue(c.is_defunct) + self.assertIsInstance(c.last_error, socket_error) + self.assertTrue(c.connected_event.is_set()) + + def test_partial_header_read(self): + c = self.make_connection() + + header = self.make_header_prefix(SupportedMessage) + options = self.make_options_body() + message = self.make_msg(header, options) + + self.get_socket(c).recv.return_value = message[0:1] + c.handle_read(*self.null_handle_function_args) + self.assertEqual(c._io_buffer.cql_frame_buffer.getvalue(), message[0:1]) + + self.get_socket(c).recv.return_value = message[1:] + c.handle_read(*self.null_handle_function_args) + self.assertEqual(bytes(), c._io_buffer.io_buffer.getvalue()) + + # let it write out a StartupMessage + c.handle_write(*self.null_handle_function_args) + + header = self.make_header_prefix(ReadyMessage, stream_id=1) + self.get_socket(c).recv.return_value = self.make_msg(header) + c.handle_read(*self.null_handle_function_args) + + self.assertTrue(c.connected_event.is_set()) + self.assertFalse(c.is_defunct) + + def test_partial_message_read(self): + c = self.make_connection() + + header = self.make_header_prefix(SupportedMessage) + options = self.make_options_body() + message = self.make_msg(header, options) + + # read in the first nine bytes + self.get_socket(c).recv.return_value = message[:9] + c.handle_read(*self.null_handle_function_args) + self.assertEqual(c._io_buffer.cql_frame_buffer.getvalue(), message[:9]) + + # ... 
then read in the rest + self.get_socket(c).recv.return_value = message[9:] + c.handle_read(*self.null_handle_function_args) + self.assertEqual(bytes(), c._io_buffer.io_buffer.getvalue()) + + # let it write out a StartupMessage + c.handle_write(*self.null_handle_function_args) + + header = self.make_header_prefix(ReadyMessage, stream_id=1) + self.get_socket(c).recv.return_value = self.make_msg(header) + c.handle_read(*self.null_handle_function_args) + + self.assertTrue(c.connected_event.is_set()) + self.assertFalse(c.is_defunct) + + def test_mixed_message_and_buffer_sizes(self): + """ + Validate that all messages are processed with different scenarios: + + - various message sizes + - various socket buffer sizes + - random non-fatal errors raised + """ + c = self.make_connection() + c.process_io_buffer = Mock() + + errors = cycle([ + ssl.SSLError(ssl.SSL_ERROR_WANT_READ), + ssl.SSLError(ssl.SSL_ERROR_WANT_WRITE), + socket_error(errno.EWOULDBLOCK), + socket_error(errno.EAGAIN) + ]) + + for buffer_size in [512, 1024, 2048, 4096, 8192]: + c.in_buffer_size = buffer_size + + for i in range(1, 15): + c.process_io_buffer.reset_mock() + c._io_buffer._io_buffer = io.BytesIO() + message = io.BytesIO(b'a' * (2**i)) + + def recv_side_effect(*args): + if random.randint(1,10) % 3 == 0: + raise next(errors) + return message.read(args[0]) + + self.get_socket(c).recv.side_effect = recv_side_effect + c.handle_read(*self.null_handle_function_args) + if c._io_buffer.io_buffer.tell(): + c.process_io_buffer.assert_called_once() + else: + c.process_io_buffer.assert_not_called() diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py new file mode 100644 index 0000000000..49607d4e48 --- /dev/null +++ b/tests/unit/test_auth.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# # Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cassandra.auth import PlainTextAuthenticator + +import unittest + + +class TestPlainTextAuthenticator(unittest.TestCase): + + def test_evaluate_challenge_with_unicode_data(self): + authenticator = PlainTextAuthenticator("johnӁ", "doeӁ") + self.assertEqual( + authenticator.evaluate_challenge(b'PLAIN-START'), + "\x00johnӁ\x00doeӁ".encode('utf-8') + ) diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index 763875c9f8..69a65855a0 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -1,27 +1,154 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import unittest + +import logging +import socket +import uuid + +from unittest.mock import patch, Mock + +from cassandra import ConsistencyLevel, DriverException, Timeout, Unavailable, RequestExecutionException, ReadTimeout, WriteTimeout, CoordinationFailure, ReadFailure, WriteFailure, FunctionFailure, AlreadyExists,\ + InvalidRequest, Unauthorized, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, RequestValidationException, ConfigurationException, ProtocolVersion +from cassandra.cluster import _Scheduler, Session, Cluster, default_lbp_factory, \ + ExecutionProfile, _ConfigMode, EXEC_PROFILE_DEFAULT +from cassandra.connection import SniEndPoint, SniEndPointFactory +from cassandra.pool import Host +from cassandra.policies import HostDistance, RetryPolicy, RoundRobinPolicy, DowngradingConsistencyRetryPolicy, SimpleConvictionPolicy +from cassandra.query import SimpleStatement, named_tuple_factory, tuple_factory +from tests.unit.utils import mock_session_pools +from tests import connection_class + + +log = logging.getLogger(__name__) + + +class ExceptionTypeTest(unittest.TestCase): + + def test_exception_types(self): + """ + PYTHON-443 + Sanity check to ensure we don't unintentionally change class hierarchy of exception types + """ + self.assertTrue(issubclass(Unavailable, DriverException)) + self.assertTrue(issubclass(Unavailable, RequestExecutionException)) + + self.assertTrue(issubclass(ReadTimeout, DriverException)) + self.assertTrue(issubclass(ReadTimeout, RequestExecutionException)) + self.assertTrue(issubclass(ReadTimeout, Timeout)) + + self.assertTrue(issubclass(WriteTimeout, DriverException)) + self.assertTrue(issubclass(WriteTimeout, RequestExecutionException)) + self.assertTrue(issubclass(WriteTimeout, Timeout)) + + self.assertTrue(issubclass(CoordinationFailure, DriverException)) + self.assertTrue(issubclass(CoordinationFailure, RequestExecutionException)) + + self.assertTrue(issubclass(ReadFailure, DriverException)) + self.assertTrue(issubclass(ReadFailure, RequestExecutionException)) + self.assertTrue(issubclass(ReadFailure, CoordinationFailure)) + + self.assertTrue(issubclass(WriteFailure, DriverException)) + self.assertTrue(issubclass(WriteFailure, RequestExecutionException)) + self.assertTrue(issubclass(WriteFailure, CoordinationFailure)) + + self.assertTrue(issubclass(FunctionFailure, DriverException)) + self.assertTrue(issubclass(FunctionFailure, RequestExecutionException)) -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa + self.assertTrue(issubclass(RequestValidationException, DriverException)) -from mock import patch, Mock + self.assertTrue(issubclass(ConfigurationException, DriverException)) + self.assertTrue(issubclass(ConfigurationException, RequestValidationException)) -from cassandra import ConsistencyLevel -from cassandra.cluster import _Scheduler, Session -from cassandra.query import 
SimpleStatement + self.assertTrue(issubclass(AlreadyExists, DriverException)) + self.assertTrue(issubclass(AlreadyExists, RequestValidationException)) + self.assertTrue(issubclass(AlreadyExists, ConfigurationException)) + + self.assertTrue(issubclass(InvalidRequest, DriverException)) + self.assertTrue(issubclass(InvalidRequest, RequestValidationException)) + + self.assertTrue(issubclass(Unauthorized, DriverException)) + self.assertTrue(issubclass(Unauthorized, RequestValidationException)) + + self.assertTrue(issubclass(AuthenticationFailed, DriverException)) + + self.assertTrue(issubclass(OperationTimedOut, DriverException)) + + self.assertTrue(issubclass(UnsupportedOperation, DriverException)) + + +class MockOrderedPolicy(RoundRobinPolicy): + all_hosts = set() + + def make_query_plan(self, working_keyspace=None, query=None): + return sorted(self.all_hosts, key=lambda x: x.endpoint.ssl_options['server_hostname']) + +class ClusterTest(unittest.TestCase): + + def test_tuple_for_contact_points(self): + cluster = Cluster(contact_points=[('localhost', 9045), ('127.0.0.2', 9046), '127.0.0.3'], port=9999) + localhost_addr = set([addr[0] for addr in [t for (_,_,_,_,t) in socket.getaddrinfo("localhost",80)]]) + for cp in cluster.endpoints_resolved: + if cp.address in localhost_addr: + self.assertEqual(cp.port, 9045) + elif cp.address == '127.0.0.2': + self.assertEqual(cp.port, 9046) + else: + self.assertEqual(cp.address, '127.0.0.3') + self.assertEqual(cp.port, 9999) + + def test_invalid_contact_point_types(self): + with self.assertRaises(ValueError): + Cluster(contact_points=[None], protocol_version=4, connect_timeout=1) + with self.assertRaises(TypeError): + Cluster(contact_points="not a sequence", protocol_version=4, connect_timeout=1) + + def test_requests_in_flight_threshold(self): + d = HostDistance.LOCAL + mn = 3 + mx = 5 + c = Cluster(protocol_version=2) + c.set_min_requests_per_connection(d, mn) + c.set_max_requests_per_connection(d, mx) + # min underflow, max, overflow + for n in (-1, mx, 127): + self.assertRaises(ValueError, c.set_min_requests_per_connection, d, n) + # max underflow, under min, overflow + for n in (0, mn, 128): + self.assertRaises(ValueError, c.set_max_requests_per_connection, d, n) + + # Validate that at least the default LBP can create a query plan with end points that resolve + # to different addresses initially. This may not be exactly how things play out in practice + # (the control connection will muck with this even if nothing else does) but it should be + # a pretty good approximation. 
+ def test_query_plan_for_sni_contains_unique_addresses(self): + node_cnt = 5 + def _mocked_proxy_dns_resolution(self): + return [(socket.AF_UNIX, socket.SOCK_STREAM, 0, None, ('127.0.0.%s' % (i,), 9042)) for i in range(node_cnt)] + + c = Cluster() + lbp = c.load_balancing_policy + lbp.local_dc = "dc1" + factory = SniEndPointFactory("proxy.foo.bar", 9042) + for host in (Host(factory.create({"host_id": uuid.uuid4().hex, "dc": "dc1"}), SimpleConvictionPolicy) for _ in range(node_cnt)): + lbp.on_up(host) + with patch.object(SniEndPoint, '_resolve_proxy_addresses', _mocked_proxy_dns_resolution): + addrs = [host.endpoint.resolve() for host in lbp.make_query_plan()] + # single SNI endpoint should be resolved to multiple unique IP addresses + self.assertEqual(len(addrs), len(set(addrs))) class SchedulerTest(unittest.TestCase): @@ -29,7 +156,7 @@ class SchedulerTest(unittest.TestCase): @patch('time.time', return_value=3) # always queue at same time @patch('cassandra.cluster._Scheduler.run') # don't actually run the thread - def test_event_delay_timing(self, *args): + def test_event_delay_timing(self, *_): """ Schedule something with a time collision to make sure the heap comparison works @@ -41,31 +168,446 @@ def test_event_delay_timing(self, *args): class SessionTest(unittest.TestCase): - # TODO: this suite could be expanded; for now just adding a test covering a PR + def setUp(self): + if connection_class is None: + raise unittest.SkipTest('libev does not appear to be installed correctly') + connection_class.initialize_reactor() - @patch('cassandra.cluster.ResponseFuture._make_query_plan') - def test_default_serial_consistency_level(self, *args): + # TODO: this suite could be expanded; for now just adding a test covering a PR + @mock_session_pools + def test_default_serial_consistency_level_ep(self, *_): """ - Make sure default_serial_consistency_level passes through to a query message. + Make sure default_serial_consistency_level passes through to a query message using execution profiles. Also make sure Statement.serial_consistency_level overrides the default. 
PR #510 """ - s = Session(Mock(protocol_version=4), []) + c = Cluster(protocol_version=4) + s = Session(c, [Host("127.0.0.1", SimpleConvictionPolicy)]) # default is None - self.assertIsNone(s.default_serial_consistency_level) + default_profile = c.profile_manager.default + self.assertIsNone(default_profile.serial_consistency_level) - sentinel = 1001 - for cl in (None, ConsistencyLevel.LOCAL_SERIAL, ConsistencyLevel.SERIAL, sentinel): - s.default_serial_consistency_level = cl + for cl in (None, ConsistencyLevel.LOCAL_SERIAL, ConsistencyLevel.SERIAL): + s.get_execution_profile(EXEC_PROFILE_DEFAULT).serial_consistency_level = cl # default is passed through - f = s._create_response_future(query='', parameters=[], trace=False, custom_payload={}, timeout=100) + f = s.execute_async(query='') self.assertEqual(f.message.serial_consistency_level, cl) # any non-None statement setting takes precedence for cl_override in (ConsistencyLevel.LOCAL_SERIAL, ConsistencyLevel.SERIAL): - f = s._create_response_future(SimpleStatement(query_string='', serial_consistency_level=cl_override), parameters=[], trace=False, custom_payload={}, timeout=100) + f = s.execute_async(SimpleStatement(query_string='', serial_consistency_level=cl_override)) + self.assertEqual(default_profile.serial_consistency_level, cl) + self.assertEqual(f.message.serial_consistency_level, cl_override) + + @mock_session_pools + def test_default_serial_consistency_level_legacy(self, *_): + """ + Make sure default_serial_consistency_level passes through to a query message using legacy settings. + Also make sure Statement.serial_consistency_level overrides the default. + + PR #510 + """ + c = Cluster(protocol_version=4) + s = Session(c, [Host("127.0.0.1", SimpleConvictionPolicy)]) + + # default is None + self.assertIsNone(s.default_serial_consistency_level) + + # Should fail + with self.assertRaises(ValueError): + s.default_serial_consistency_level = ConsistencyLevel.ANY + with self.assertRaises(ValueError): + s.default_serial_consistency_level = 1001 + + for cl in (None, ConsistencyLevel.LOCAL_SERIAL, ConsistencyLevel.SERIAL): + s.default_serial_consistency_level = cl + + # any non-None statement setting takes precedence + for cl_override in (ConsistencyLevel.LOCAL_SERIAL, ConsistencyLevel.SERIAL): + f = s.execute_async(SimpleStatement(query_string='', serial_consistency_level=cl_override)) self.assertEqual(s.default_serial_consistency_level, cl) self.assertEqual(f.message.serial_consistency_level, cl_override) + + +class ProtocolVersionTests(unittest.TestCase): + + def test_protocol_downgrade_test(self): + lower = ProtocolVersion.get_lower_supported(ProtocolVersion.DSE_V2) + self.assertEqual(ProtocolVersion.DSE_V1, lower) + lower = ProtocolVersion.get_lower_supported(ProtocolVersion.DSE_V1) + self.assertEqual(ProtocolVersion.V5,lower) + lower = ProtocolVersion.get_lower_supported(ProtocolVersion.V5) + self.assertEqual(ProtocolVersion.V4,lower) + lower = ProtocolVersion.get_lower_supported(ProtocolVersion.V4) + self.assertEqual(ProtocolVersion.V3,lower) + lower = ProtocolVersion.get_lower_supported(ProtocolVersion.V3) + self.assertEqual(ProtocolVersion.V2,lower) + lower = ProtocolVersion.get_lower_supported(ProtocolVersion.V2) + self.assertEqual(ProtocolVersion.V1, lower) + lower = ProtocolVersion.get_lower_supported(ProtocolVersion.V1) + self.assertEqual(0, lower) + + self.assertTrue(ProtocolVersion.uses_error_code_map(ProtocolVersion.DSE_V1)) + self.assertTrue(ProtocolVersion.uses_int_query_flags(ProtocolVersion.DSE_V1)) + + 
self.assertFalse(ProtocolVersion.uses_error_code_map(ProtocolVersion.V4)) + self.assertFalse(ProtocolVersion.uses_int_query_flags(ProtocolVersion.V4)) + + +class ExecutionProfileTest(unittest.TestCase): + def setUp(self): + if connection_class is None: + raise unittest.SkipTest('libev does not appear to be installed correctly') + connection_class.initialize_reactor() + + def _verify_response_future_profile(self, rf, prof): + self.assertEqual(rf._load_balancer, prof.load_balancing_policy) + self.assertEqual(rf._retry_policy, prof.retry_policy) + self.assertEqual(rf.message.consistency_level, prof.consistency_level) + self.assertEqual(rf.message.serial_consistency_level, prof.serial_consistency_level) + self.assertEqual(rf.timeout, prof.request_timeout) + self.assertEqual(rf.row_factory, prof.row_factory) + + @mock_session_pools + def test_default_exec_parameters(self): + cluster = Cluster() + self.assertEqual(cluster._config_mode, _ConfigMode.UNCOMMITTED) + self.assertEqual(cluster.load_balancing_policy.__class__, default_lbp_factory().__class__) + self.assertEqual(cluster.profile_manager.default.load_balancing_policy.__class__, default_lbp_factory().__class__) + self.assertEqual(cluster.default_retry_policy.__class__, RetryPolicy) + self.assertEqual(cluster.profile_manager.default.retry_policy.__class__, RetryPolicy) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + self.assertEqual(session.default_timeout, 10.0) + self.assertEqual(cluster.profile_manager.default.request_timeout, 10.0) + self.assertEqual(session.default_consistency_level, ConsistencyLevel.LOCAL_ONE) + self.assertEqual(cluster.profile_manager.default.consistency_level, ConsistencyLevel.LOCAL_ONE) + self.assertEqual(session.default_serial_consistency_level, None) + self.assertEqual(cluster.profile_manager.default.serial_consistency_level, None) + self.assertEqual(session.row_factory, named_tuple_factory) + self.assertEqual(cluster.profile_manager.default.row_factory, named_tuple_factory) + + @mock_session_pools + def test_default_legacy(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), default_retry_policy=DowngradingConsistencyRetryPolicy()) + self.assertEqual(cluster._config_mode, _ConfigMode.LEGACY) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + session.default_timeout = 3.7 + session.default_consistency_level = ConsistencyLevel.ALL + session.default_serial_consistency_level = ConsistencyLevel.SERIAL + rf = session.execute_async("query") + expected_profile = ExecutionProfile(cluster.load_balancing_policy, cluster.default_retry_policy, + session.default_consistency_level, session.default_serial_consistency_level, + session.default_timeout, session.row_factory) + self._verify_response_future_profile(rf, expected_profile) + + @mock_session_pools + def test_default_profile(self): + non_default_profile = ExecutionProfile(RoundRobinPolicy(), *[object() for _ in range(2)]) + cluster = Cluster(execution_profiles={'non-default': non_default_profile}) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + + self.assertEqual(cluster._config_mode, _ConfigMode.PROFILES) + + default_profile = cluster.profile_manager.profiles[EXEC_PROFILE_DEFAULT] + rf = session.execute_async("query") + self._verify_response_future_profile(rf, default_profile) + + rf = session.execute_async("query", execution_profile='non-default') + self._verify_response_future_profile(rf, non_default_profile) + + for name, ep in 
cluster.profile_manager.profiles.items(): + self.assertEqual(ep, session.get_execution_profile(name)) + + # invalid ep + with self.assertRaises(ValueError): + session.get_execution_profile('non-existent') + + def test_serial_consistency_level_validation(self): + # should pass + ep = ExecutionProfile(RoundRobinPolicy(), serial_consistency_level=ConsistencyLevel.SERIAL) + ep = ExecutionProfile(RoundRobinPolicy(), serial_consistency_level=ConsistencyLevel.LOCAL_SERIAL) + + # should not pass + with self.assertRaises(ValueError): + ep = ExecutionProfile(RoundRobinPolicy(), serial_consistency_level=ConsistencyLevel.ANY) + with self.assertRaises(ValueError): + ep = ExecutionProfile(RoundRobinPolicy(), serial_consistency_level=42) + + @mock_session_pools + def test_statement_params_override_legacy(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), default_retry_policy=DowngradingConsistencyRetryPolicy()) + self.assertEqual(cluster._config_mode, _ConfigMode.LEGACY) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + + ss = SimpleStatement("query", retry_policy=DowngradingConsistencyRetryPolicy(), + consistency_level=ConsistencyLevel.ALL, serial_consistency_level=ConsistencyLevel.SERIAL) + my_timeout = 1.1234 + + self.assertNotEqual(ss.retry_policy.__class__, cluster.default_retry_policy) + self.assertNotEqual(ss.consistency_level, session.default_consistency_level) + self.assertNotEqual(ss._serial_consistency_level, session.default_serial_consistency_level) + self.assertNotEqual(my_timeout, session.default_timeout) + + rf = session.execute_async(ss, timeout=my_timeout) + expected_profile = ExecutionProfile(load_balancing_policy=cluster.load_balancing_policy, retry_policy=ss.retry_policy, + request_timeout=my_timeout, consistency_level=ss.consistency_level, + serial_consistency_level=ss._serial_consistency_level) + self._verify_response_future_profile(rf, expected_profile) + + @mock_session_pools + def test_statement_params_override_profile(self): + non_default_profile = ExecutionProfile(RoundRobinPolicy(), *[object() for _ in range(2)]) + cluster = Cluster(execution_profiles={'non-default': non_default_profile}) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + + self.assertEqual(cluster._config_mode, _ConfigMode.PROFILES) + + rf = session.execute_async("query", execution_profile='non-default') + + ss = SimpleStatement("query", retry_policy=DowngradingConsistencyRetryPolicy(), + consistency_level=ConsistencyLevel.ALL, serial_consistency_level=ConsistencyLevel.SERIAL) + my_timeout = 1.1234 + + self.assertNotEqual(ss.retry_policy.__class__, rf._load_balancer.__class__) + self.assertNotEqual(ss.consistency_level, rf.message.consistency_level) + self.assertNotEqual(ss._serial_consistency_level, rf.message.serial_consistency_level) + self.assertNotEqual(my_timeout, rf.timeout) + + rf = session.execute_async(ss, timeout=my_timeout, execution_profile='non-default') + expected_profile = ExecutionProfile(non_default_profile.load_balancing_policy, ss.retry_policy, + ss.consistency_level, ss._serial_consistency_level, my_timeout, non_default_profile.row_factory) + self._verify_response_future_profile(rf, expected_profile) + + @mock_session_pools + def test_no_profile_with_legacy(self): + # don't construct with both + self.assertRaises(ValueError, Cluster, load_balancing_policy=RoundRobinPolicy(), execution_profiles={'a': ExecutionProfile()}) + self.assertRaises(ValueError, Cluster, 
default_retry_policy=DowngradingConsistencyRetryPolicy(), execution_profiles={'a': ExecutionProfile()}) + self.assertRaises(ValueError, Cluster, load_balancing_policy=RoundRobinPolicy(), + default_retry_policy=DowngradingConsistencyRetryPolicy(), execution_profiles={'a': ExecutionProfile()}) + + # can't add after + cluster = Cluster(load_balancing_policy=RoundRobinPolicy()) + self.assertRaises(ValueError, cluster.add_execution_profile, 'name', ExecutionProfile()) + + # session settings lock out profiles + cluster = Cluster() + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + for attr, value in (('default_timeout', 1), + ('default_consistency_level', ConsistencyLevel.ANY), + ('default_serial_consistency_level', ConsistencyLevel.SERIAL), + ('row_factory', tuple_factory)): + cluster._config_mode = _ConfigMode.UNCOMMITTED + setattr(session, attr, value) + self.assertRaises(ValueError, cluster.add_execution_profile, 'name' + attr, ExecutionProfile()) + + # don't accept profile + self.assertRaises(ValueError, session.execute_async, "query", execution_profile='some name here') + + @mock_session_pools + def test_no_legacy_with_profile(self): + cluster_init = Cluster(execution_profiles={'name': ExecutionProfile()}) + cluster_add = Cluster() + cluster_add.add_execution_profile('name', ExecutionProfile()) + # for clusters with profiles added either way... + for cluster in (cluster_init, cluster_init): + # don't allow legacy parameters set + for attr, value in (('default_retry_policy', RetryPolicy()), + ('load_balancing_policy', default_lbp_factory())): + self.assertRaises(ValueError, setattr, cluster, attr, value) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + for attr, value in (('default_timeout', 1), + ('default_consistency_level', ConsistencyLevel.ANY), + ('default_serial_consistency_level', ConsistencyLevel.SERIAL), + ('row_factory', tuple_factory)): + self.assertRaises(ValueError, setattr, session, attr, value) + + @mock_session_pools + def test_profile_name_value(self): + + internalized_profile = ExecutionProfile(RoundRobinPolicy(), *[object() for _ in range(2)]) + cluster = Cluster(execution_profiles={'by-name': internalized_profile}) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + self.assertEqual(cluster._config_mode, _ConfigMode.PROFILES) + + rf = session.execute_async("query", execution_profile='by-name') + self._verify_response_future_profile(rf, internalized_profile) + + by_value = ExecutionProfile(RoundRobinPolicy(), *[object() for _ in range(2)]) + rf = session.execute_async("query", execution_profile=by_value) + self._verify_response_future_profile(rf, by_value) + + @mock_session_pools + def test_exec_profile_clone(self): + + cluster = Cluster(execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(), 'one': ExecutionProfile()}) + session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) + + profile_attrs = {'request_timeout': 1, + 'consistency_level': ConsistencyLevel.ANY, + 'serial_consistency_level': ConsistencyLevel.SERIAL, + 'row_factory': tuple_factory, + 'retry_policy': RetryPolicy(), + 'load_balancing_policy': default_lbp_factory()} + reference_attributes = ('retry_policy', 'load_balancing_policy') + + # default and one named + for profile in (EXEC_PROFILE_DEFAULT, 'one'): + active = session.get_execution_profile(profile) + clone = session.execution_profile_clone_update(profile) + self.assertIsNot(clone, active) + + all_updated = 
session.execution_profile_clone_update(clone, **profile_attrs) + self.assertIsNot(all_updated, clone) + for attr, value in profile_attrs.items(): + self.assertEqual(getattr(clone, attr), getattr(active, attr)) + if attr in reference_attributes: + self.assertIs(getattr(clone, attr), getattr(active, attr)) + self.assertNotEqual(getattr(all_updated, attr), getattr(active, attr)) + + # cannot clone nonexistent profile + self.assertRaises(ValueError, session.execution_profile_clone_update, 'DOES NOT EXIST', **profile_attrs) + + def test_no_profiles_same_name(self): + # can override default in init + cluster = Cluster(execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(), 'one': ExecutionProfile()}) + + # cannot update default + self.assertRaises(ValueError, cluster.add_execution_profile, EXEC_PROFILE_DEFAULT, ExecutionProfile()) + + # cannot update named init + self.assertRaises(ValueError, cluster.add_execution_profile, 'one', ExecutionProfile()) + + # can add new name + cluster.add_execution_profile('two', ExecutionProfile()) + + # cannot add a profile added dynamically + self.assertRaises(ValueError, cluster.add_execution_profile, 'two', ExecutionProfile()) + + def test_warning_on_no_lbp_with_contact_points_legacy_mode(self): + """ + Test that users are warned when they instantiate a Cluster object in + legacy mode with contact points but no load-balancing policy. + + @since 3.12.0 + @jira_ticket PYTHON-812 + @expected_result logs + + @test_category configuration + """ + self._check_warning_on_no_lbp_with_contact_points( + cluster_kwargs={'contact_points': ['127.0.0.1']} + ) + + def test_warning_on_no_lbp_with_contact_points_profile_mode(self): + """ + Test that users are warned when they instantiate a Cluster object in + execution profile mode with contact points but no load-balancing + policy. + + @since 3.12.0 + @jira_ticket PYTHON-812 + @expected_result logs + + @test_category configuration + """ + self._check_warning_on_no_lbp_with_contact_points(cluster_kwargs={ + 'contact_points': ['127.0.0.1'], + 'execution_profiles': {EXEC_PROFILE_DEFAULT: ExecutionProfile()} + }) + + @mock_session_pools + def _check_warning_on_no_lbp_with_contact_points(self, cluster_kwargs): + with patch('cassandra.cluster.log') as patched_logger: + Cluster(**cluster_kwargs) + patched_logger.warning.assert_called_once() + warning_message = patched_logger.warning.call_args[0][0] + self.assertIn('please specify a load-balancing policy', warning_message) + self.assertIn("contact_points = ['127.0.0.1']", warning_message) + + def test_no_warning_on_contact_points_with_lbp_legacy_mode(self): + """ + Test that users aren't warned when they instantiate a Cluster object + with contact points and a load-balancing policy in legacy mode. + + @since 3.12.0 + @jira_ticket PYTHON-812 + @expected_result no logs + + @test_category configuration + """ + self._check_no_warning_on_contact_points_with_lbp({ + 'contact_points': ['127.0.0.1'], + 'load_balancing_policy': object() + }) + + def test_no_warning_on_contact_points_with_lbp_profiles_mode(self): + """ + Test that users aren't warned when they instantiate a Cluster object + with contact points and a load-balancing policy in execution profile + mode. 
+ + @since 3.12.0 + @jira_ticket PYTHON-812 + @expected_result no logs + + @test_category configuration + """ + ep_with_lbp = ExecutionProfile(load_balancing_policy=object()) + self._check_no_warning_on_contact_points_with_lbp(cluster_kwargs={ + 'contact_points': ['127.0.0.1'], + 'execution_profiles': { + EXEC_PROFILE_DEFAULT: ep_with_lbp + } + }) + + @mock_session_pools + def _check_no_warning_on_contact_points_with_lbp(self, cluster_kwargs): + """ + Test that users aren't warned when they instantiate a Cluster object + with contact points and a load-balancing policy. + + @since 3.12.0 + @jira_ticket PYTHON-812 + @expected_result no logs + + @test_category configuration + """ + with patch('cassandra.cluster.log') as patched_logger: + Cluster(**cluster_kwargs) + patched_logger.warning.assert_not_called() + + @mock_session_pools + def test_warning_adding_no_lbp_ep_to_cluster_with_contact_points(self): + ep_with_lbp = ExecutionProfile(load_balancing_policy=object()) + cluster = Cluster( + contact_points=['127.0.0.1'], + execution_profiles={EXEC_PROFILE_DEFAULT: ep_with_lbp}) + with patch('cassandra.cluster.log') as patched_logger: + cluster.add_execution_profile( + name='no_lbp', + profile=ExecutionProfile() + ) + + patched_logger.warning.assert_called_once() + warning_message = patched_logger.warning.call_args[0][0] + self.assertIn('no_lbp', warning_message) + self.assertIn('trying to add', warning_message) + self.assertIn('please specify a load-balancing policy', warning_message) + + @mock_session_pools + def test_no_warning_adding_lbp_ep_to_cluster_with_contact_points(self): + ep_with_lbp = ExecutionProfile(load_balancing_policy=object()) + cluster = Cluster( + contact_points=['127.0.0.1'], + execution_profiles={EXEC_PROFILE_DEFAULT: ep_with_lbp}) + with patch('cassandra.cluster.log') as patched_logger: + cluster.add_execution_profile( + name='with_lbp', + profile=ExecutionProfile(load_balancing_policy=Mock(name='lbp')) + ) + + patched_logger.warning.assert_not_called() diff --git a/tests/unit/test_concurrent.py b/tests/unit/test_concurrent.py index a1f953220f..18e8381185 100644 --- a/tests/unit/test_concurrent.py +++ b/tests/unit/test_concurrent.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -13,17 +15,21 @@ # limitations under the License. 
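# Illustrative sketch of the warning-free setup the PYTHON-812 tests above expect:
# contact points plus an explicit load-balancing policy on the default execution
# profile. Standalone example; RoundRobinPolicy is just one possible policy choice.
from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT
from cassandra.policies import RoundRobinPolicy

profile = ExecutionProfile(load_balancing_policy=RoundRobinPolicy())
cluster = Cluster(contact_points=['127.0.0.1'],
                  execution_profiles={EXEC_PROFILE_DEFAULT: profile})
# Passing contact_points while leaving the load-balancing policy unset is what
# triggers the "please specify a load-balancing policy" warning those tests assert.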
-try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest + from itertools import cycle -from mock import Mock +from unittest.mock import Mock import time import threading -from six.moves.queue import PriorityQueue +from queue import PriorityQueue +import sys +import platform -from cassandra.concurrent import execute_concurrent, execute_concurrent_with_args +from cassandra.cluster import Cluster, Session +from cassandra.concurrent import execute_concurrent, execute_concurrent_with_args, execute_concurrent_async +from cassandra.pool import Host +from cassandra.policies import SimpleConvictionPolicy +from tests.unit.utils import mock_session_pools class MockResponseResponseFuture(): @@ -34,8 +40,9 @@ class MockResponseResponseFuture(): _query_trace = None _col_names = None + _col_types = None - # a list pending callbacks, these will be prioritized in reverse or normal orderd + # a list pending callbacks, these will be prioritized in reverse or normal order pending_callbacks = PriorityQueue() def __init__(self, reverse): @@ -111,7 +118,6 @@ def run(self): self._stopper.wait(.001) return - class ConcurrencyTest((unittest.TestCase)): def test_results_ordering_forward(self): @@ -175,7 +181,7 @@ def insert_and_validate_list_results(self, reverse, slowdown): This utility method will execute submit various statements for execution using the ConcurrentExecutorListResults, then invoke a separate thread to execute the callback associated with the futures registered for those statements. The parameters will toggle various timing, and ordering changes. - Finally it will validate that the results were returned in the order they were submitted + Finally, it will validate that the results were returned in the order they were submitted :param reverse: Execute the callbacks in the opposite order that they were submitted :param slowdown: Cause intermittent queries to perform slowly """ @@ -199,7 +205,7 @@ def insert_and_validate_list_generator(self, reverse, slowdown): This utility method will execute submit various statements for execution using the ConcurrentExecutorGenResults, then invoke a separate thread to execute the callback associated with the futures registered for those statements. The parameters will toggle various timing, and ordering changes. 
- Finally it will validate that the results were returned in the order they were submitted + Finally, it will validate that the results were returned in the order they were submitted :param reverse: Execute the callbacks in the opposite order that they were submitted :param slowdown: Cause intermittent queries to perform slowly """ @@ -211,10 +217,11 @@ def insert_and_validate_list_generator(self, reverse, slowdown): t = TimedCallableInvoker(our_handler, slowdown=slowdown) t.start() - results = execute_concurrent(mock_session, statements_and_params, results_generator=True) - - self.validate_result_ordering(results) - t.stop() + try: + results = execute_concurrent(mock_session, statements_and_params, results_generator=True) + self.validate_result_ordering(results) + finally: + t.stop() def validate_result_ordering(self, results): """ @@ -226,5 +233,79 @@ def validate_result_ordering(self, results): for success, result in results: self.assertTrue(success) current_time_added = list(result)[0] - self.assertLess(last_time_added, current_time_added) + + # Windows clock granularity makes this equal most of the time + if "Windows" in platform.system(): + self.assertLessEqual(last_time_added, current_time_added) + else: + self.assertLess(last_time_added, current_time_added) last_time_added = current_time_added + + def insert_and_validate_list_async(self, reverse, slowdown): + """ + This utility method will execute submit various statements for execution using execute_concurrent_async, + then invoke a separate thread to execute the callback associated with the futures registered + for those statements. The parameters will toggle various timing, and ordering changes. + Finally it will validate that the results were returned in the order they were submitted + :param reverse: Execute the callbacks in the opposite order that they were submitted + :param slowdown: Cause intermittent queries to perform slowly + """ + our_handler = MockResponseResponseFuture(reverse=reverse) + mock_session = Mock() + statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]), + [(i, ) for i in range(100)]) + mock_session.execute_async.return_value = our_handler + + t = TimedCallableInvoker(our_handler, slowdown=slowdown) + t.start() + try: + future = execute_concurrent_async(mock_session, statements_and_params) + results = future.result() + self.validate_result_ordering(results) + finally: + t.stop() + + def test_results_ordering_async_forward(self): + """ + This tests the ordering of our execute_concurrent_async function + when queries complete in the order they were executed. + """ + self.insert_and_validate_list_async(False, False) + + def test_results_ordering_async_reverse(self): + """ + This tests the ordering of our execute_concurrent_async function + when queries complete in the reverse order they were executed. + """ + self.insert_and_validate_list_async(True, False) + + def test_results_ordering_async_forward_slowdown(self): + """ + This tests the ordering of our execute_concurrent_async function + when queries complete in the order they were executed, with slow queries mixed in. + """ + self.insert_and_validate_list_async(False, True) + + def test_results_ordering_async_reverse_slowdown(self): + """ + This tests the ordering of our execute_concurrent_async function + when queries complete in the reverse order they were executed, with slow queries mixed in. 
+ """ + self.insert_and_validate_list_async(True, True) + + @mock_session_pools + def test_recursion_limited(self): + """ + Verify that recursion is controlled when raise_on_first_error=False and something is wrong with the query. + + PYTHON-585 + """ + max_recursion = sys.getrecursionlimit() + s = Session(Cluster(), [Host("127.0.0.1", SimpleConvictionPolicy)]) + self.assertRaises(TypeError, execute_concurrent_with_args, s, "doesn't matter", [('param',)] * max_recursion, raise_on_first_error=True) + + results = execute_concurrent_with_args(s, "doesn't matter", [('param',)] * max_recursion, raise_on_first_error=False) # previously + self.assertEqual(len(results), max_recursion) + for r in results: + self.assertFalse(r[0]) + self.assertIsInstance(r[1], TypeError) diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py index 15fa6e722b..3bca654c55 100644 --- a/tests/unit/test_connection.py +++ b/tests/unit/test_connection.py @@ -1,53 +1,54 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-try: - import unittest2 as unittest -except ImportError: - import unittest # noqa - -from mock import Mock, ANY, call, patch -import six -from six import BytesIO +import unittest +from io import BytesIO import time from threading import Lock +from unittest.mock import Mock, ANY, call, patch -from cassandra.cluster import Cluster, Session +from cassandra import OperationTimedOut +from cassandra.cluster import Cluster from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT, ProtocolError, - locally_supported_compressions, ConnectionHeartbeat, _Frame, Timer, TimerManager) + locally_supported_compressions, ConnectionHeartbeat, _Frame, Timer, TimerManager, + ConnectionException, DefaultEndPoint) from cassandra.marshal import uint8_pack, uint32_pack, int32_pack from cassandra.protocol import (write_stringmultimap, write_int, write_string, SupportedMessage, ProtocolHandler) +from tests.util import wait_until + class ConnectionTest(unittest.TestCase): def make_connection(self): - c = Connection('1.2.3.4') + c = Connection(DefaultEndPoint('1.2.3.4')) c._socket = Mock() c._socket.send.side_effect = lambda x: len(x) return c def make_header_prefix(self, message_class, version=Connection.protocol_version, stream_id=0): if Connection.protocol_version < 3: - return six.binary_type().join(map(uint8_pack, [ + return bytes().join(map(uint8_pack, [ 0xff & (HEADER_DIRECTION_TO_CLIENT | version), 0, # flags (compression) stream_id, message_class.opcode # opcode ])) else: - return six.binary_type().join(map(uint8_pack, [ + return bytes().join(map(uint8_pack, [ 0xff & (HEADER_DIRECTION_TO_CLIENT | version), 0, # flags (compression) 0, # MSB for v3+ stream @@ -55,7 +56,6 @@ def make_header_prefix(self, message_class, version=Connection.protocol_version, message_class.opcode # opcode ])) - def make_options_body(self): options_buf = BytesIO() write_stringmultimap(options_buf, { @@ -73,6 +73,21 @@ def make_error_body(self, code, msg): def make_msg(self, header, body=""): return header + uint32_pack(len(body)) + body + def test_connection_endpoint(self): + endpoint = DefaultEndPoint('1.2.3.4') + c = Connection(endpoint) + self.assertEqual(c.endpoint, endpoint) + self.assertEqual(c.endpoint.address, endpoint.address) + + c = Connection(host=endpoint) # kwarg + self.assertEqual(c.endpoint, endpoint) + self.assertEqual(c.endpoint.address, endpoint.address) + + c = Connection('10.0.0.1') + endpoint = DefaultEndPoint('10.0.0.1') + self.assertEqual(c.endpoint, endpoint) + self.assertEqual(c.endpoint.address, endpoint.address) + def test_bad_protocol_version(self, *args): c = self.make_connection() c._requests = Mock() @@ -82,7 +97,7 @@ def test_bad_protocol_version(self, *args): header = self.make_header_prefix(SupportedMessage, version=0x7f) options = self.make_options_body() message = self.make_msg(header, options) - c._iobuf = BytesIO() + c._iobuf._io_buffer = BytesIO() c._iobuf.write(message) c.process_io_buffer() @@ -99,7 +114,7 @@ def test_negative_body_length(self, *args): # read in a SupportedMessage response header = self.make_header_prefix(SupportedMessage) message = header + int32_pack(-13) - c._iobuf = BytesIO() + c._iobuf._io_buffer = BytesIO() c._iobuf.write(message) c.process_io_buffer() @@ -110,13 +125,10 @@ def test_negative_body_length(self, *args): def test_unsupported_cql_version(self, *args): c = self.make_connection() - c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message)} + c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message, [])} 
c.defunct = Mock() c.cql_version = "3.0.3" - # read in a SupportedMessage response - header = self.make_header_prefix(SupportedMessage) - options_buf = BytesIO() write_stringmultimap(options_buf, { 'CQL_VERSION': ['7.8.9'], @@ -133,7 +145,7 @@ def test_unsupported_cql_version(self, *args): def test_prefer_lz4_compression(self, *args): c = self.make_connection() - c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message)} + c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message, [])} c.defunct = Mock() c.cql_version = "3.0.3" @@ -156,7 +168,7 @@ def test_prefer_lz4_compression(self, *args): def test_requested_compression_not_available(self, *args): c = self.make_connection() - c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message)} + c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message, [])} c.defunct = Mock() # request lz4 compression c.compression = "lz4" @@ -166,9 +178,6 @@ def test_requested_compression_not_available(self, *args): locally_supported_compressions['lz4'] = ('lz4compress', 'lz4decompress') locally_supported_compressions['snappy'] = ('snappycompress', 'snappydecompress') - # read in a SupportedMessage response - header = self.make_header_prefix(SupportedMessage) - # the server only supports snappy options_buf = BytesIO() write_stringmultimap(options_buf, { @@ -186,7 +195,7 @@ def test_requested_compression_not_available(self, *args): def test_use_requested_compression(self, *args): c = self.make_connection() - c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message)} + c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message, [])} c.defunct = Mock() # request snappy compression c.compression = "snappy" @@ -196,9 +205,6 @@ def test_use_requested_compression(self, *args): locally_supported_compressions['lz4'] = ('lz4compress', 'lz4decompress') locally_supported_compressions['snappy'] = ('snappycompress', 'snappydecompress') - # read in a SupportedMessage response - header = self.make_header_prefix(SupportedMessage) - # the server only supports snappy options_buf = BytesIO() write_stringmultimap(options_buf, { @@ -275,9 +281,11 @@ def make_get_holders(len): get_holders = Mock(return_value=holders) return get_holders - def run_heartbeat(self, get_holders_fun, count=2, interval=0.05): - ch = ConnectionHeartbeat(interval, get_holders_fun) - time.sleep(interval * count) + def run_heartbeat(self, get_holders_fun, count=2, interval=0.05, timeout=0.05): + ch = ConnectionHeartbeat(interval, get_holders_fun, timeout=timeout) + # wait until the thread is started + wait_until(lambda: get_holders_fun.call_count > 0, 0.01, 100) + time.sleep(interval * (count-1)) ch.stop() self.assertTrue(get_holders_fun.call_count) @@ -287,7 +295,7 @@ def test_empty_connections(self, *args): self.run_heartbeat(get_holders, count) - self.assertGreaterEqual(get_holders.call_count, count - 1) # lower bound to account for thread spinup time + self.assertGreaterEqual(get_holders.call_count, count-1) self.assertLessEqual(get_holders.call_count, count) holder = get_holders.return_value[0] holder.get_connections.assert_has_calls([call()] * get_holders.call_count) @@ -344,7 +352,7 @@ def test_no_req_ids(self, *args): get_holders = self.make_get_holders(1) max_connection = Mock(spec=Connection, host='localhost', lock=Lock(), - max_request_id=in_flight, in_flight=in_flight, + max_request_id=in_flight - 1, in_flight=in_flight, is_idle=True, is_defunct=False, is_closed=False) holder = 
get_holders.return_value[0] holder.get_connections.return_value.append(max_connection) @@ -356,7 +364,8 @@ def test_no_req_ids(self, *args): self.assertEqual(max_connection.send_msg.call_count, 0) self.assertEqual(max_connection.send_msg.call_count, 0) max_connection.defunct.assert_has_calls([call(ANY)] * get_holders.call_count) - holder.return_connection.assert_has_calls([call(max_connection)] * get_holders.call_count) + holder.return_connection.assert_has_calls( + [call(max_connection)] * get_holders.call_count) def test_unexpected_response(self, *args): request_id = 999 @@ -382,9 +391,10 @@ def send_msg(msg, req_id, msg_callback): connection.send_msg.assert_has_calls([call(ANY, request_id, ANY)] * get_holders.call_count) connection.defunct.assert_has_calls([call(ANY)] * get_holders.call_count) exc = connection.defunct.call_args_list[0][0][0] - self.assertIsInstance(exc, Exception) - self.assertEqual(exc.args, Exception('Connection heartbeat failure').args) - holder.return_connection.assert_has_calls([call(connection)] * get_holders.call_count) + self.assertIsInstance(exc, ConnectionException) + self.assertRegex(exc.args[0], r'^Received unexpected response to OptionsMessage.*') + holder.return_connection.assert_has_calls( + [call(connection)] * get_holders.call_count) def test_timeout(self, *args): request_id = 999 @@ -394,7 +404,8 @@ def test_timeout(self, *args): def send_msg(msg, req_id, msg_callback): pass - connection = Mock(spec=Connection, host='localhost', + # we used endpoint=X here because it's a mock and we need connection.endpoint to be set + connection = Mock(spec=Connection, endpoint=DefaultEndPoint('localhost'), max_request_id=127, lock=Lock(), in_flight=0, is_idle=True, @@ -410,9 +421,11 @@ def send_msg(msg, req_id, msg_callback): connection.send_msg.assert_has_calls([call(ANY, request_id, ANY)] * get_holders.call_count) connection.defunct.assert_has_calls([call(ANY)] * get_holders.call_count) exc = connection.defunct.call_args_list[0][0][0] - self.assertIsInstance(exc, Exception) - self.assertEqual(exc.args, Exception('Connection heartbeat failure').args) - holder.return_connection.assert_has_calls([call(connection)] * get_holders.call_count) + self.assertIsInstance(exc, OperationTimedOut) + self.assertEqual(exc.errors, 'Connection heartbeat timeout after 0.05 seconds') + self.assertEqual(exc.last_host, DefaultEndPoint('localhost')) + holder.return_connection.assert_has_calls( + [call(connection)] * get_holders.call_count) class TimerTest(unittest.TestCase): @@ -429,3 +442,49 @@ def test_timer_collision(self): tm.add_timer(t2) # Prior to #466: "TypeError: unorderable types: Timer() < Timer()" tm.service_timeouts() + + +class DefaultEndPointTest(unittest.TestCase): + + def test_default_endpoint_properties(self): + endpoint = DefaultEndPoint('10.0.0.1') + self.assertEqual(endpoint.address, '10.0.0.1') + self.assertEqual(endpoint.port, 9042) + self.assertEqual(str(endpoint), '10.0.0.1:9042') + + endpoint = DefaultEndPoint('10.0.0.1', 8888) + self.assertEqual(endpoint.address, '10.0.0.1') + self.assertEqual(endpoint.port, 8888) + self.assertEqual(str(endpoint), '10.0.0.1:8888') + + def test_endpoint_equality(self): + self.assertEqual( + DefaultEndPoint('10.0.0.1'), + DefaultEndPoint('10.0.0.1') + ) + + self.assertEqual( + DefaultEndPoint('10.0.0.1'), + DefaultEndPoint('10.0.0.1', 9042) + ) + + self.assertNotEqual( + DefaultEndPoint('10.0.0.1'), + DefaultEndPoint('10.0.0.2') + ) + + self.assertNotEqual( + DefaultEndPoint('10.0.0.1'), + DefaultEndPoint('10.0.0.1', 0000) + ) + 
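# Small illustrative sketch: DefaultEndPoint behaves as a value object, which is what
# the property/equality tests above (and the endpoint-keyed host maps later in this
# patch) rely on -- equal address/port pairs compare and hash equal, and resolve()
# returns the (address, port) tuple.
from cassandra.connection import DefaultEndPoint

ep = DefaultEndPoint('10.0.0.1', 8888)
hosts_by_endpoint = {ep: 'host metadata placeholder'}

assert DefaultEndPoint('10.0.0.1', 8888) in hosts_by_endpoint  # same address/port, same key
assert DefaultEndPoint('10.0.0.1') not in hosts_by_endpoint    # default port 9042 differs
assert ep.resolve() == ('10.0.0.1', 8888)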
+ def test_endpoint_resolve(self): + self.assertEqual( + DefaultEndPoint('10.0.0.1').resolve(), + ('10.0.0.1', 9042) + ) + + self.assertEqual( + DefaultEndPoint('10.0.0.1', 3232).resolve(), + ('10.0.0.1', 3232) + ) diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index 9fac7e2f46..618bb42b1f 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,20 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest from concurrent.futures import ThreadPoolExecutor -from mock import Mock, ANY, call +from unittest.mock import Mock, ANY, call from cassandra import OperationTimedOut, SchemaTargetType, SchemaChangeType from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS -from cassandra.cluster import ControlConnection, _Scheduler +from cassandra.cluster import ControlConnection, _Scheduler, ProfileManager, EXEC_PROFILE_DEFAULT, ExecutionProfile from cassandra.pool import Host +from cassandra.connection import EndPoint, DefaultEndPoint, DefaultEndPointFactory from cassandra.policies import (SimpleConvictionPolicy, RoundRobinPolicy, - ConstantReconnectionPolicy) + ConstantReconnectionPolicy, IdentityTranslator) PEER_IP = "foobar" @@ -34,19 +34,26 @@ class MockMetadata(object): def __init__(self): self.hosts = { - "192.168.1.0": Host("192.168.1.0", SimpleConvictionPolicy), - "192.168.1.1": Host("192.168.1.1", SimpleConvictionPolicy), - "192.168.1.2": Host("192.168.1.2", SimpleConvictionPolicy) + DefaultEndPoint("192.168.1.0"): Host(DefaultEndPoint("192.168.1.0"), SimpleConvictionPolicy), + DefaultEndPoint("192.168.1.1"): Host(DefaultEndPoint("192.168.1.1"), SimpleConvictionPolicy), + DefaultEndPoint("192.168.1.2"): Host(DefaultEndPoint("192.168.1.2"), SimpleConvictionPolicy) } for host in self.hosts.values(): host.set_up() + host.release_version = "3.11" self.cluster_name = None self.partitioner = None self.token_map = {} - def get_host(self, rpc_address): - return self.hosts.get(rpc_address) + def get_host(self, endpoint_or_address, port=None): + if not isinstance(endpoint_or_address, EndPoint): + for host in self.hosts.values(): + if (host.address == endpoint_or_address and + (port is None or host.broadcast_rpc_port is None or host.broadcast_rpc_port == port)): + return host + else: + return self.hosts.get(endpoint_or_address) def all_hosts(self): return self.hosts.values() @@ -59,8 +66,9 @@ def rebuild_token_map(self, partitioner, token_map): class MockCluster(object): 
max_schema_agreement_wait = 5 - load_balancing_policy = RoundRobinPolicy() + profile_manager = ProfileManager() reconnection_policy = ConstantReconnectionPolicy(2) + address_translator = IdentityTranslator() down_host = None contact_points = [] is_shutdown = False @@ -71,11 +79,13 @@ def __init__(self): self.removed_hosts = [] self.scheduler = Mock(spec=_Scheduler) self.executor = Mock(spec=ThreadPoolExecutor) + self.profile_manager.profiles[EXEC_PROFILE_DEFAULT] = ExecutionProfile(RoundRobinPolicy()) + self.endpoint_factory = DefaultEndPointFactory().configure(self) - def add_host(self, address, datacenter, rack, signal=False, refresh_nodes=True): - host = Host(address, SimpleConvictionPolicy, datacenter, rack) + def add_host(self, endpoint, datacenter, rack, signal=False, refresh_nodes=True): + host = Host(endpoint, SimpleConvictionPolicy, datacenter, rack) self.added_hosts.append(host) - return host + return host, True def remove_host(self, host): self.removed_hosts.append(host) @@ -87,28 +97,44 @@ def on_down(self, host, is_host_addition): self.down_host = host +def _node_meta_results(local_results, peer_results): + """ + creates a pair of ResultMessages from (col_names, parsed_rows) + """ + local_response = ResultMessage(kind=RESULT_KIND_ROWS) + local_response.column_names = local_results[0] + local_response.parsed_rows = local_results[1] + + peer_response = ResultMessage(kind=RESULT_KIND_ROWS) + peer_response.column_names = peer_results[0] + peer_response.parsed_rows = peer_results[1] + + return peer_response, local_response + + class MockConnection(object): is_defunct = False def __init__(self): - self.host = "192.168.1.0" + self.endpoint = DefaultEndPoint("192.168.1.0") self.local_results = [ ["schema_version", "cluster_name", "data_center", "rack", "partitioner", "release_version", "tokens"], [["a", "foocluster", "dc1", "rack1", "Murmur3Partitioner", "2.2.0", ["0", "100", "200"]]] ] self.peer_results = [ - ["rpc_address", "peer", "schema_version", "data_center", "rack", "tokens"], - [["192.168.1.1", "10.0.0.1", "a", "dc1", "rack1", ["1", "101", "201"]], - ["192.168.1.2", "10.0.0.2", "a", "dc1", "rack1", ["2", "102", "202"]]] + ["rpc_address", "peer", "schema_version", "data_center", "rack", "tokens", "host_id"], + [["192.168.1.1", "10.0.0.1", "a", "dc1", "rack1", ["1", "101", "201"], "uuid1"], + ["192.168.1.2", "10.0.0.2", "a", "dc1", "rack1", ["2", "102", "202"], "uuid2"]] ] - local_response = ResultMessage( - kind=RESULT_KIND_ROWS, results=self.local_results) - peer_response = ResultMessage( - kind=RESULT_KIND_ROWS, results=self.peer_results) - self.wait_for_responses = Mock(return_value=(peer_response, local_response)) + self.peer_results_v2 = [ + ["native_address", "native_port", "peer", "peer_port", "schema_version", "data_center", "rack", "tokens", "host_id"], + [["192.168.1.1", 9042, "10.0.0.1", 7042, "a", "dc1", "rack1", ["1", "101", "201"], "uuid1"], + ["192.168.1.2", 9042, "10.0.0.2", 7040, "a", "dc1", "rack1", ["2", "102", "202"], "uuid2"]] + ] + self.wait_for_responses = Mock(return_value=_node_meta_results(self.local_results, self.peer_results)) class FakeTime(object): @@ -125,47 +151,29 @@ def sleep(self, amount): class ControlConnectionTest(unittest.TestCase): + _matching_schema_preloaded_results = _node_meta_results( + local_results=(["schema_version", "cluster_name", "data_center", "rack", "partitioner", "release_version", "tokens", "host_id"], + [["a", "foocluster", "dc1", "rack1", "Murmur3Partitioner", "2.2.0", ["0", "100", "200"], "uuid1"]]), + 
peer_results=(["rpc_address", "peer", "schema_version", "data_center", "rack", "tokens", "host_id"], + [["192.168.1.1", "10.0.0.1", "a", "dc1", "rack1", ["1", "101", "201"], "uuid2"], + ["192.168.1.2", "10.0.0.2", "a", "dc1", "rack1", ["2", "102", "202"], "uuid3"]])) + + _nonmatching_schema_preloaded_results = _node_meta_results( + local_results=(["schema_version", "cluster_name", "data_center", "rack", "partitioner", "release_version", "tokens", "host_id"], + [["a", "foocluster", "dc1", "rack1", "Murmur3Partitioner", "2.2.0", ["0", "100", "200"], "uuid1"]]), + peer_results=(["rpc_address", "peer", "schema_version", "data_center", "rack", "tokens", "host_id"], + [["192.168.1.1", "10.0.0.1", "a", "dc1", "rack1", ["1", "101", "201"], "uuid2"], + ["192.168.1.2", "10.0.0.2", "b", "dc1", "rack1", ["2", "102", "202"], "uuid3"]])) + def setUp(self): self.cluster = MockCluster() self.connection = MockConnection() self.time = FakeTime() - self.control_connection = ControlConnection(self.cluster, 1, 0, 0) + self.control_connection = ControlConnection(self.cluster, 1, 0, 0, 0) self.control_connection._connection = self.connection self.control_connection._time = self.time - def _get_matching_schema_preloaded_results(self): - local_results = [ - ["schema_version", "cluster_name", "data_center", "rack", "partitioner", "release_version", "tokens"], - [["a", "foocluster", "dc1", "rack1", "Murmur3Partitioner", "2.2.0", ["0", "100", "200"]]] - ] - local_response = ResultMessage(kind=RESULT_KIND_ROWS, results=local_results) - - peer_results = [ - ["rpc_address", "peer", "schema_version", "data_center", "rack", "tokens"], - [["192.168.1.1", "10.0.0.1", "a", "dc1", "rack1", ["1", "101", "201"]], - ["192.168.1.2", "10.0.0.2", "a", "dc1", "rack1", ["2", "102", "202"]]] - ] - peer_response = ResultMessage(kind=RESULT_KIND_ROWS, results=peer_results) - - return (peer_response, local_response) - - def _get_nonmatching_schema_preloaded_results(self): - local_results = [ - ["schema_version", "cluster_name", "data_center", "rack", "partitioner", "release_version", "tokens"], - [["a", "foocluster", "dc1", "rack1", "Murmur3Partitioner", "2.2.0", ["0", "100", "200"]]] - ] - local_response = ResultMessage(kind=RESULT_KIND_ROWS, results=local_results) - - peer_results = [ - ["rpc_address", "peer", "schema_version", "data_center", "rack", "tokens"], - [["192.168.1.1", "10.0.0.1", "a", "dc1", "rack1", ["1", "101", "201"]], - ["192.168.1.2", "10.0.0.2", "b", "dc1", "rack1", ["2", "102", "202"]]] - ] - peer_response = ResultMessage(kind=RESULT_KIND_ROWS, results=peer_results) - - return (peer_response, local_response) - def test_wait_for_schema_agreement(self): """ Basic test with all schema versions agreeing @@ -178,8 +186,7 @@ def test_wait_for_schema_agreement_uses_preloaded_results_if_given(self): """ wait_for_schema_agreement uses preloaded results if given for shared table queries """ - preloaded_results = self._get_matching_schema_preloaded_results() - + preloaded_results = self._matching_schema_preloaded_results self.assertTrue(self.control_connection.wait_for_schema_agreement(preloaded_results=preloaded_results)) # the control connection should not have slept at all self.assertEqual(self.time.clock, 0) @@ -190,8 +197,7 @@ def test_wait_for_schema_agreement_falls_back_to_querying_if_schemas_dont_match_ """ wait_for_schema_agreement requery if schema does not match using preloaded results """ - preloaded_results = self._get_nonmatching_schema_preloaded_results() - + preloaded_results = 
self._nonmatching_schema_preloaded_results self.assertTrue(self.control_connection.wait_for_schema_agreement(preloaded_results=preloaded_results)) # the control connection should not have slept at all self.assertEqual(self.time.clock, 0) @@ -222,7 +228,7 @@ def test_wait_for_schema_agreement_skipping(self): # change the schema version on one of the existing entries self.connection.peer_results[1][1][3] = 'c' - self.cluster.metadata.get_host('192.168.1.1').is_up = False + self.cluster.metadata.get_host(DefaultEndPoint('192.168.1.1')).is_up = False self.assertTrue(self.control_connection.wait_for_schema_agreement()) self.assertEqual(self.time.clock, 0) @@ -234,8 +240,8 @@ def test_wait_for_schema_agreement_rpc_lookup(self): self.connection.peer_results[1].append( ["0.0.0.0", PEER_IP, "b", "dc1", "rack1", ["3", "103", "203"]] ) - host = Host("0.0.0.0", SimpleConvictionPolicy) - self.cluster.metadata.hosts[PEER_IP] = host + host = Host(DefaultEndPoint("0.0.0.0"), SimpleConvictionPolicy) + self.cluster.metadata.hosts[DefaultEndPoint("foobar")] = host host.is_up = False # even though the new host has a different schema version, it's @@ -266,12 +272,45 @@ def test_refresh_nodes_and_tokens(self): self.assertEqual(self.connection.wait_for_responses.call_count, 1) + def test_refresh_nodes_and_tokens_with_invalid_peers(self): + def refresh_and_validate_added_hosts(): + self.connection.wait_for_responses = Mock(return_value=_node_meta_results( + self.connection.local_results, self.connection.peer_results)) + self.control_connection.refresh_node_list_and_token_map() + self.assertEqual(1, len(self.cluster.added_hosts)) # only one valid peer found + + # peersV1 + del self.connection.peer_results[:] + self.connection.peer_results.extend([ + ["rpc_address", "peer", "schema_version", "data_center", "rack", "tokens", "host_id"], + [["192.168.1.3", "10.0.0.1", "a", "dc1", "rack1", ["1", "101", "201"], 'uuid5'], + # all others are invalid + [None, None, "a", "dc1", "rack1", ["1", "101", "201"], 'uuid1'], + ["192.168.1.7", "10.0.0.1", "a", None, "rack1", ["1", "101", "201"], 'uuid2'], + ["192.168.1.6", "10.0.0.1", "a", "dc1", None, ["1", "101", "201"], 'uuid3'], + ["192.168.1.5", "10.0.0.1", "a", "dc1", "rack1", None, 'uuid4'], + ["192.168.1.4", "10.0.0.1", "a", "dc1", "rack1", ["1", "101", "201"], None]]]) + refresh_and_validate_added_hosts() + + # peersV2 + del self.cluster.added_hosts[:] + del self.connection.peer_results[:] + self.connection.peer_results.extend([ + ["native_address", "native_port", "peer", "peer_port", "schema_version", "data_center", "rack", "tokens", "host_id"], + [["192.168.1.4", 9042, "10.0.0.1", 7042, "a", "dc1", "rack1", ["1", "101", "201"], "uuid1"], + # all others are invalid + [None, 9042, None, 7040, "a", "dc1", "rack1", ["2", "102", "202"], "uuid2"], + ["192.168.1.5", 9042, "10.0.0.2", 7040, "a", None, "rack1", ["2", "102", "202"], "uuid2"], + ["192.168.1.5", 9042, "10.0.0.2", 7040, "a", "dc1", None, ["2", "102", "202"], "uuid2"], + ["192.168.1.5", 9042, "10.0.0.2", 7040, "a", "dc1", "rack1", None, "uuid2"], + ["192.168.1.5", 9042, "10.0.0.2", 7040, "a", "dc1", "rack1", ["2", "102", "202"], None]]]) + refresh_and_validate_added_hosts() + def test_refresh_nodes_and_tokens_uses_preloaded_results_if_given(self): """ refresh_nodes_and_tokens uses preloaded results if given for shared table queries """ - preloaded_results = self._get_matching_schema_preloaded_results() - + preloaded_results = self._matching_schema_preloaded_results 
self.control_connection._refresh_node_list_and_token_map(self.connection, preloaded_results=preloaded_results) meta = self.cluster.metadata self.assertEqual(meta.partitioner, 'Murmur3Partitioner') @@ -303,7 +342,7 @@ def test_refresh_nodes_and_tokens_no_partitioner(self): def test_refresh_nodes_and_tokens_add_host(self): self.connection.peer_results[1].append( - ["192.168.1.3", "10.0.0.3", "a", "dc1", "rack1", ["3", "103", "203"]] + ["192.168.1.3", "10.0.0.3", "a", "dc1", "rack1", ["3", "103", "203"], "uuid3"] ) self.cluster.scheduler.schedule = lambda delay, f, *args, **kwargs: f(*args, **kwargs) self.control_connection.refresh_node_list_and_token_map() @@ -311,6 +350,7 @@ def test_refresh_nodes_and_tokens_add_host(self): self.assertEqual(self.cluster.added_hosts[0].address, "192.168.1.3") self.assertEqual(self.cluster.added_hosts[0].datacenter, "dc1") self.assertEqual(self.cluster.added_hosts[0].rack, "rack1") + self.assertEqual(self.cluster.added_hosts[0].host_id, "uuid3") def test_refresh_nodes_and_tokens_remove_host(self): del self.connection.peer_results[1][1] @@ -346,7 +386,8 @@ def test_handle_topology_change(self): } self.cluster.scheduler.reset_mock() self.control_connection._handle_topology_change(event) - self.cluster.scheduler.schedule_unique.assert_called_once_with(ANY, self.control_connection.refresh_node_list_and_token_map) + + self.cluster.scheduler.schedule_unique.assert_called_once_with(ANY, self.control_connection._refresh_nodes_if_not_up, None) event = { 'change_type': 'REMOVED_NODE', @@ -362,7 +403,7 @@ def test_handle_topology_change(self): } self.cluster.scheduler.reset_mock() self.control_connection._handle_topology_change(event) - self.cluster.scheduler.schedule_unique.assert_called_once_with(ANY, self.control_connection.refresh_node_list_and_token_map) + self.cluster.scheduler.schedule_unique.assert_called_once_with(ANY, self.control_connection._refresh_nodes_if_not_up, None) def test_handle_status_change(self): event = { @@ -376,11 +417,11 @@ def test_handle_status_change(self): # do the same with a known Host event = { 'change_type': 'UP', - 'address': ('192.168.1.0', 9000) + 'address': ('192.168.1.0', 9042) } self.cluster.scheduler.reset_mock() self.control_connection._handle_status_change(event) - host = self.cluster.metadata.hosts['192.168.1.0'] + host = self.cluster.metadata.hosts[DefaultEndPoint('192.168.1.0')] self.cluster.scheduler.schedule_unique.assert_called_once_with(ANY, self.cluster.on_up, host) self.cluster.scheduler.schedule.reset_mock() @@ -397,7 +438,7 @@ def test_handle_status_change(self): 'address': ('192.168.1.0', 9000) } self.control_connection._handle_status_change(event) - host = self.cluster.metadata.hosts['192.168.1.0'] + host = self.cluster.metadata.hosts[DefaultEndPoint('192.168.1.0')] self.assertIs(host, self.cluster.down_host) def test_handle_schema_change(self): @@ -440,7 +481,7 @@ def test_refresh_disabled(self): 'address': ('1.2.3.4', 9000) } - cc_no_schema_refresh = ControlConnection(cluster, 1, -1, 0) + cc_no_schema_refresh = ControlConnection(cluster, 1, -1, 0, 0) cluster.scheduler.reset_mock() # no call on schema refresh @@ -452,9 +493,9 @@ def test_refresh_disabled(self): cc_no_schema_refresh._handle_status_change(status_event) cc_no_schema_refresh._handle_topology_change(topo_event) cluster.scheduler.schedule_unique.assert_has_calls([call(ANY, cc_no_schema_refresh.refresh_node_list_and_token_map), - call(ANY, cc_no_schema_refresh.refresh_node_list_and_token_map)]) + call(ANY, cc_no_schema_refresh._refresh_nodes_if_not_up, 
None)]) - cc_no_topo_refresh = ControlConnection(cluster, 1, 0, -1) + cc_no_topo_refresh = ControlConnection(cluster, 1, 0, -1, 0) cluster.scheduler.reset_mock() # no call on topo refresh @@ -469,6 +510,46 @@ def test_refresh_disabled(self): call(0.0, cc_no_topo_refresh.refresh_schema, **schema_event)]) + def test_refresh_nodes_and_tokens_add_host_detects_port(self): + del self.connection.peer_results[:] + self.connection.peer_results.extend(self.connection.peer_results_v2) + self.connection.peer_results[1].append( + ["192.168.1.3", 555, "10.0.0.3", 666, "a", "dc1", "rack1", ["3", "103", "203"], "uuid3"] + ) + self.connection.wait_for_responses = Mock(return_value=_node_meta_results( + self.connection.local_results, self.connection.peer_results)) + self.cluster.scheduler.schedule = lambda delay, f, *args, **kwargs: f(*args, **kwargs) + self.control_connection.refresh_node_list_and_token_map() + self.assertEqual(1, len(self.cluster.added_hosts)) + self.assertEqual(self.cluster.added_hosts[0].endpoint.address, "192.168.1.3") + self.assertEqual(self.cluster.added_hosts[0].endpoint.port, 555) + self.assertEqual(self.cluster.added_hosts[0].broadcast_rpc_address, "192.168.1.3") + self.assertEqual(self.cluster.added_hosts[0].broadcast_rpc_port, 555) + self.assertEqual(self.cluster.added_hosts[0].broadcast_address, "10.0.0.3") + self.assertEqual(self.cluster.added_hosts[0].broadcast_port, 666) + self.assertEqual(self.cluster.added_hosts[0].datacenter, "dc1") + self.assertEqual(self.cluster.added_hosts[0].rack, "rack1") + + def test_refresh_nodes_and_tokens_add_host_detects_invalid_port(self): + del self.connection.peer_results[:] + self.connection.peer_results.extend(self.connection.peer_results_v2) + self.connection.peer_results[1].append( + ["192.168.1.3", -1, "10.0.0.3", 0, "a", "dc1", "rack1", ["3", "103", "203"], "uuid3"] + ) + self.connection.wait_for_responses = Mock(return_value=_node_meta_results( + self.connection.local_results, self.connection.peer_results)) + self.cluster.scheduler.schedule = lambda delay, f, *args, **kwargs: f(*args, **kwargs) + self.control_connection.refresh_node_list_and_token_map() + self.assertEqual(1, len(self.cluster.added_hosts)) + self.assertEqual(self.cluster.added_hosts[0].endpoint.address, "192.168.1.3") + self.assertEqual(self.cluster.added_hosts[0].endpoint.port, 9042) # fallback default + self.assertEqual(self.cluster.added_hosts[0].broadcast_rpc_address, "192.168.1.3") + self.assertEqual(self.cluster.added_hosts[0].broadcast_rpc_port, None) + self.assertEqual(self.cluster.added_hosts[0].broadcast_address, "10.0.0.3") + self.assertEqual(self.cluster.added_hosts[0].broadcast_port, None) + self.assertEqual(self.cluster.added_hosts[0].datacenter, "dc1") + self.assertEqual(self.cluster.added_hosts[0].rack, "rack1") + class EventTimingTest(unittest.TestCase): """ @@ -481,7 +562,7 @@ def setUp(self): self.time = FakeTime() # Use 2 for the schema_event_refresh_window which is what we would normally default to. - self.control_connection = ControlConnection(self.cluster, 1, 2, 0) + self.control_connection = ControlConnection(self.cluster, 1, 2, 0, 0) self.control_connection._connection = self.connection self.control_connection._time = self.time diff --git a/tests/unit/test_endpoints.py b/tests/unit/test_endpoints.py new file mode 100644 index 0000000000..4352afb9a5 --- /dev/null +++ b/tests/unit/test_endpoints.py @@ -0,0 +1,79 @@ +# Copyright DataStax, Inc. 
+# +# Licensed under the DataStax DSE Driver License; +# you may not use this file except in compliance with the License. +# +# You may obtain a copy of the License at +# +# http://www.datastax.com/terms/datastax-dse-driver-license-terms +import unittest + +import itertools + +from cassandra.connection import DefaultEndPoint, SniEndPoint, SniEndPointFactory + +from unittest.mock import patch + + +def socket_getaddrinfo(*args): + return [ + (0, 0, 0, '', ('127.0.0.1', 30002)), + (0, 0, 0, '', ('127.0.0.2', 30002)), + (0, 0, 0, '', ('127.0.0.3', 30002)) + ] + + +@patch('socket.getaddrinfo', socket_getaddrinfo) +class SniEndPointTest(unittest.TestCase): + + endpoint_factory = SniEndPointFactory("proxy.datastax.com", 30002) + + def test_sni_endpoint_properties(self): + + endpoint = self.endpoint_factory.create_from_sni('test') + self.assertEqual(endpoint.address, 'proxy.datastax.com') + self.assertEqual(endpoint.port, 30002) + self.assertEqual(endpoint._server_name, 'test') + self.assertEqual(str(endpoint), 'proxy.datastax.com:30002:test') + + def test_endpoint_equality(self): + self.assertNotEqual( + DefaultEndPoint('10.0.0.1'), + self.endpoint_factory.create_from_sni('10.0.0.1') + ) + + self.assertEqual( + self.endpoint_factory.create_from_sni('10.0.0.1'), + self.endpoint_factory.create_from_sni('10.0.0.1') + ) + + self.assertNotEqual( + self.endpoint_factory.create_from_sni('10.0.0.1'), + self.endpoint_factory.create_from_sni('10.0.0.0') + ) + + self.assertNotEqual( + self.endpoint_factory.create_from_sni('10.0.0.1'), + SniEndPointFactory("proxy.datastax.com", 9999).create_from_sni('10.0.0.1') + ) + + def test_endpoint_resolve(self): + ips = ['127.0.0.1', '127.0.0.2', '127.0.0.3'] + it = itertools.cycle(ips) + + endpoint = self.endpoint_factory.create_from_sni('test') + for i in range(10): + (address, _) = endpoint.resolve() + self.assertEqual(address, next(it)) + + def test_sni_resolution_start_index(self): + factory = SniEndPointFactory("proxy.datastax.com", 9999) + initial_index = factory._init_index + + endpoint1 = factory.create_from_sni('sni1') + self.assertEqual(factory._init_index, initial_index + 1) + self.assertEqual(endpoint1._index, factory._init_index) + + endpoint2 = factory.create_from_sni('sni2') + self.assertEqual(factory._init_index, initial_index + 2) + self.assertEqual(endpoint2._index, factory._init_index) diff --git a/tests/unit/test_exception.py b/tests/unit/test_exception.py index a88b5260fa..4758970d9c 100644 --- a/tests/unit/test_exception.py +++ b/tests/unit/test_exception.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
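# A brief sketch of the SNI endpoint behaviour the tests above rely on: each
# SniEndPoint keeps the proxy address/port plus a server name, successive
# resolve() calls rotate through the proxy's resolved A records, and the
# factory staggers each endpoint's starting index so connections spread across
# those records. The server name below is illustrative; in the tests DNS
# resolution is mocked, so real use assumes the proxy hostname is resolvable.
from cassandra.connection import SniEndPointFactory

factory = SniEndPointFactory("proxy.datastax.com", 30002)
ep = factory.create_from_sni('node-1-uuid')          # illustrative SNI server name
assert str(ep) == 'proxy.datastax.com:30002:node-1-uuid'
address, port = ep.resolve()                         # one of the proxy's resolved IPs, port 30002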
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,10 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -try: - import unittest2 as unittest -except ImportError: - import unittest +import unittest from cassandra import Unavailable, Timeout, ConsistencyLevel import re diff --git a/tests/unit/test_host_connection_pool.py b/tests/unit/test_host_connection_pool.py index fb0ca21711..d8b5ca976e 100644 --- a/tests/unit/test_host_connection_pool.py +++ b/tests/unit/test_host_connection_pool.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,21 +14,20 @@ # See the License for the specific language governing permissions and # limitations under the License. 
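# The hunks below fold the pool tests into a shared base class exercised against
# both pool implementations; a condensed sketch of that pattern, using only the
# names the diff itself introduces (test bodies elided):
import unittest
from cassandra.pool import HostConnection, HostConnectionPool

class _PoolTests(unittest.TestCase):
    __test__ = False                 # the base class itself is not collected
    PoolImpl = None                  # pool class under test, set by subclasses
    uses_single_connection = None    # guards assertions that only apply to multi-connection pools

class HostConnectionPoolTests(_PoolTests):
    __test__ = True
    PoolImpl = HostConnectionPool
    uses_single_connection = False

class HostConnectionTests(_PoolTests):
    __test__ = True
    PoolImpl = HostConnection
    uses_single_connection = True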
-try: - import unittest2 as unittest -except ImportError: - import unittest # noqa - -from mock import Mock, NonCallableMagicMock +import unittest from threading import Thread, Event, Lock +from unittest.mock import Mock, NonCallableMagicMock from cassandra.cluster import Session from cassandra.connection import Connection -from cassandra.pool import Host, HostConnectionPool, NoConnectionsAvailable +from cassandra.pool import HostConnection, HostConnectionPool +from cassandra.pool import Host, NoConnectionsAvailable from cassandra.policies import HostDistance, SimpleConvictionPolicy - -class HostConnectionPoolTests(unittest.TestCase): +class _PoolTests(unittest.TestCase): + __test__ = False + PoolImpl = None + uses_single_connection = None def make_session(self): session = NonCallableMagicMock(spec=Session, keyspace='foobarkeyspace') @@ -41,8 +42,8 @@ def test_borrow_and_return(self): conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100) session.cluster.connection_factory.return_value = conn - pool = HostConnectionPool(host, HostDistance.LOCAL, session) - session.cluster.connection_factory.assert_called_once_with(host.address) + pool = self.PoolImpl(host, HostDistance.LOCAL, session) + session.cluster.connection_factory.assert_called_once_with(host.endpoint, on_orphaned_stream_released=pool.on_orphaned_stream_released) c, request_id = pool.borrow_connection(timeout=0.01) self.assertIs(c, conn) @@ -51,7 +52,8 @@ def test_borrow_and_return(self): pool.return_connection(conn) self.assertEqual(0, conn.in_flight) - self.assertNotIn(conn, pool._trash) + if not self.uses_single_connection: + self.assertNotIn(conn, pool._trash) def test_failed_wait_for_connection(self): host = Mock(spec=Host, address='ip1') @@ -59,8 +61,8 @@ def test_failed_wait_for_connection(self): conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100) session.cluster.connection_factory.return_value = conn - pool = HostConnectionPool(host, HostDistance.LOCAL, session) - session.cluster.connection_factory.assert_called_once_with(host.address) + pool = self.PoolImpl(host, HostDistance.LOCAL, session) + session.cluster.connection_factory.assert_called_once_with(host.endpoint, on_orphaned_stream_released=pool.on_orphaned_stream_released) pool.borrow_connection(timeout=0.01) self.assertEqual(1, conn.in_flight) @@ -77,8 +79,8 @@ def test_successful_wait_for_connection(self): conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100, lock=Lock()) session.cluster.connection_factory.return_value = conn - pool = HostConnectionPool(host, HostDistance.LOCAL, session) - session.cluster.connection_factory.assert_called_once_with(host.address) + pool = self.PoolImpl(host, HostDistance.LOCAL, session) + session.cluster.connection_factory.assert_called_once_with(host.endpoint, on_orphaned_stream_released=pool.on_orphaned_stream_released) pool.borrow_connection(timeout=0.01) self.assertEqual(1, conn.in_flight) @@ -95,48 +97,6 @@ def get_second_conn(): t.join() self.assertEqual(0, conn.in_flight) - def test_all_connections_trashed(self): - host = Mock(spec=Host, address='ip1') - session = self.make_session() - conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100, lock=Lock()) - session.cluster.connection_factory.return_value = conn - session.cluster.get_core_connections_per_host.return_value = 1 - - # manipulate the core 
connection setting so that we can - # trash the only connection - pool = HostConnectionPool(host, HostDistance.LOCAL, session) - session.cluster.get_core_connections_per_host.return_value = 0 - pool._maybe_trash_connection(conn) - session.cluster.get_core_connections_per_host.return_value = 1 - - submit_called = Event() - - def fire_event(*args, **kwargs): - submit_called.set() - - session.submit.side_effect = fire_event - - def get_conn(): - conn.reset_mock() - c, request_id = pool.borrow_connection(1.0) - self.assertIs(conn, c) - self.assertEqual(1, conn.in_flight) - conn.set_keyspace_blocking.assert_called_once_with('foobarkeyspace') - pool.return_connection(c) - - t = Thread(target=get_conn) - t.start() - - submit_called.wait() - self.assertEqual(1, pool._scheduled_for_creation) - session.submit.assert_called_once_with(pool._create_new_connection) - - # now run the create_new_connection call - pool._create_new_connection() - - t.join() - self.assertEqual(0, conn.in_flight) - def test_spawn_when_at_max(self): host = Mock(spec=Host, address='ip1') session = self.make_session() @@ -147,8 +107,8 @@ def test_spawn_when_at_max(self): # core conns = 1, max conns = 2 session.cluster.get_max_connections_per_host.return_value = 2 - pool = HostConnectionPool(host, HostDistance.LOCAL, session) - session.cluster.connection_factory.assert_called_once_with(host.address) + pool = self.PoolImpl(host, HostDistance.LOCAL, session) + session.cluster.connection_factory.assert_called_once_with(host.endpoint, on_orphaned_stream_released=pool.on_orphaned_stream_released) pool.borrow_connection(timeout=0.01) self.assertEqual(1, conn.in_flight) @@ -160,16 +120,18 @@ def test_spawn_when_at_max(self): # purposes of this test, as long as it results in a new connection # creation being scheduled self.assertRaises(NoConnectionsAvailable, pool.borrow_connection, 0) - session.submit.assert_called_once_with(pool._create_new_connection) + if not self.uses_single_connection: + session.submit.assert_called_once_with(pool._create_new_connection) def test_return_defunct_connection(self): host = Mock(spec=Host, address='ip1') session = self.make_session() - conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100) + conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, + max_request_id=100, signaled_error=False) session.cluster.connection_factory.return_value = conn - pool = HostConnectionPool(host, HostDistance.LOCAL, session) - session.cluster.connection_factory.assert_called_once_with(host.address) + pool = self.PoolImpl(host, HostDistance.LOCAL, session) + session.cluster.connection_factory.assert_called_once_with(host.endpoint, on_orphaned_stream_released=pool.on_orphaned_stream_released) pool.borrow_connection(timeout=0.01) conn.is_defunct = True @@ -177,18 +139,19 @@ def test_return_defunct_connection(self): pool.return_connection(conn) # the connection should be closed a new creation scheduled - conn.close.assert_called_once() - session.submit.assert_called_once() + self.assertTrue(session.submit.call_args) self.assertFalse(pool.is_shutdown) def test_return_defunct_connection_on_down_host(self): host = Mock(spec=Host, address='ip1') session = self.make_session() - conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100, signaled_error=False) + conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, + max_request_id=100, 
signaled_error=False, + orphaned_threshold_reached=False) session.cluster.connection_factory.return_value = conn - pool = HostConnectionPool(host, HostDistance.LOCAL, session) - session.cluster.connection_factory.assert_called_once_with(host.address) + pool = self.PoolImpl(host, HostDistance.LOCAL, session) + session.cluster.connection_factory.assert_called_once_with(host.endpoint, on_orphaned_stream_released=pool.on_orphaned_stream_released) pool.borrow_connection(timeout=0.01) conn.is_defunct = True @@ -196,19 +159,20 @@ def test_return_defunct_connection_on_down_host(self): pool.return_connection(conn) # the connection should be closed a new creation scheduled - session.cluster.signal_connection_failure.assert_called_once() - conn.close.assert_called_once() + self.assertTrue(session.cluster.signal_connection_failure.call_args) + self.assertTrue(conn.close.call_args) self.assertFalse(session.submit.called) self.assertTrue(pool.is_shutdown) def test_return_closed_connection(self): host = Mock(spec=Host, address='ip1') session = self.make_session() - conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=True, max_request_id=100) + conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=True, max_request_id=100, + signaled_error=False, orphaned_threshold_reached=False) session.cluster.connection_factory.return_value = conn - pool = HostConnectionPool(host, HostDistance.LOCAL, session) - session.cluster.connection_factory.assert_called_once_with(host.address) + pool = self.PoolImpl(host, HostDistance.LOCAL, session) + session.cluster.connection_factory.assert_called_once_with(host.endpoint, on_orphaned_stream_released=pool.on_orphaned_stream_released) pool.borrow_connection(timeout=0.01) conn.is_closed = True @@ -216,7 +180,7 @@ def test_return_closed_connection(self): pool.return_connection(conn) # a new creation should be scheduled - session.submit.assert_called_once() + self.assertTrue(session.submit.call_args) self.assertFalse(pool.is_shutdown) def test_host_instantiations(self): @@ -240,3 +204,59 @@ def test_host_equality(self): self.assertEqual(a, b, 'Two Host instances should be equal when sharing.') self.assertNotEqual(a, c, 'Two Host instances should NOT be equal when using two different addresses.') self.assertNotEqual(b, c, 'Two Host instances should NOT be equal when using two different addresses.') + + +class HostConnectionPoolTests(_PoolTests): + __test__ = True + PoolImpl = HostConnectionPool + uses_single_connection = False + + def test_all_connections_trashed(self): + host = Mock(spec=Host, address='ip1') + session = self.make_session() + conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100, + lock=Lock()) + session.cluster.connection_factory.return_value = conn + session.cluster.get_core_connections_per_host.return_value = 1 + + # manipulate the core connection setting so that we can + # trash the only connection + pool = self.PoolImpl(host, HostDistance.LOCAL, session) + session.cluster.get_core_connections_per_host.return_value = 0 + pool._maybe_trash_connection(conn) + session.cluster.get_core_connections_per_host.return_value = 1 + + submit_called = Event() + + def fire_event(*args, **kwargs): + submit_called.set() + + session.submit.side_effect = fire_event + + def get_conn(): + conn.reset_mock() + c, request_id = pool.borrow_connection(1.0) + self.assertIs(conn, c) + self.assertEqual(1, conn.in_flight) + 
conn.set_keyspace_blocking.assert_called_once_with('foobarkeyspace') + pool.return_connection(c) + + t = Thread(target=get_conn) + t.start() + + submit_called.wait() + self.assertEqual(1, pool._scheduled_for_creation) + session.submit.assert_called_once_with(pool._create_new_connection) + + # now run the create_new_connection call + pool._create_new_connection() + + t.join() + self.assertEqual(0, conn.in_flight) + + +class HostConnectionTests(_PoolTests): + __test__ = True + PoolImpl = HostConnection + uses_single_connection = True + diff --git a/tests/unit/test_marshalling.py b/tests/unit/test_marshalling.py index 626c38b22f..9b44bb5ac2 100644 --- a/tests/unit/test_marshalling.py +++ b/tests/unit/test_marshalling.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +15,9 @@ # limitations under the License. import sys -from cassandra.marshal import bitlength -from cassandra.protocol import MAX_SUPPORTED_VERSION +from cassandra import ProtocolVersion -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest import platform from datetime import datetime, date @@ -134,11 +132,6 @@ def test_marshalling(self): msg='Marshaller for %s (%s) gave wrong type (%s instead of %s)' % (valtype, marshaller, type(whatwegot), type(serializedval))) - def test_bitlength(self): - self.assertEqual(bitlength(9), 4) - self.assertEqual(bitlength(-10), 0) - self.assertEqual(bitlength(0), 0) - def test_date(self): # separate test because it will deserialize as datetime self.assertEqual(DateType.from_binary(DateType.to_binary(date(2015, 11, 2), 1), 1), datetime(2015, 11, 2)) @@ -146,17 +139,9 @@ def test_date(self): def test_decimal(self): # testing implicit numeric conversion # int, tuple(sign, digits, exp), float - converted_types = (10001, (0, (1, 0, 0, 0, 0, 1), -3), 100.1) - - if sys.version_info < (2, 7): - # Decimal in Python 2.6 does not accept floats for lossless initialization - # Just verifying expected exception here - f = converted_types[-1] - self.assertIsInstance(f, float) - self.assertRaises(TypeError, DecimalType.to_binary, f, MAX_SUPPORTED_VERSION) - converted_types = converted_types[:-1] + converted_types = (10001, (0, (1, 0, 0, 0, 0, 1), -3), 100.1, -87.629798) - for proto_ver in range(1, MAX_SUPPORTED_VERSION + 1): + for proto_ver in range(1, ProtocolVersion.MAX_SUPPORTED + 1): for n in converted_types: expected = Decimal(n) self.assertEqual(DecimalType.from_binary(DecimalType.to_binary(n, proto_ver), proto_ver), expected) diff --git a/tests/unit/test_metadata.py b/tests/unit/test_metadata.py index c2e3513fd1..76e47a4331 100644 --- a/tests/unit/test_metadata.py +++ b/tests/unit/test_metadata.py @@ -1,38 +1,73 @@ -# 
Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import unittest -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa - -from mock import Mock +from binascii import unhexlify +import logging +from unittest.mock import Mock import os -import six +import timeit import cassandra +from cassandra.cqltypes import strip_frozen +from cassandra.marshal import uint16_unpack, uint16_pack from cassandra.metadata import (Murmur3Token, MD5Token, BytesToken, ReplicationStrategy, NetworkTopologyStrategy, SimpleStrategy, LocalStrategy, protect_name, protect_names, protect_value, is_valid_name, UserType, KeyspaceMetadata, get_schema_parser, - _UnknownStrategy) + _UnknownStrategy, ColumnMetadata, TableMetadata, + IndexMetadata, Function, Aggregate, + Metadata, TokenMap, ReplicationFactor) from cassandra.policies import SimpleConvictionPolicy from cassandra.pool import Host +log = logging.getLogger(__name__) + + +class ReplicationFactorTest(unittest.TestCase): + + def test_replication_factor_parsing(self): + rf = ReplicationFactor.create('3') + self.assertEqual(rf.all_replicas, 3) + self.assertEqual(rf.full_replicas, 3) + self.assertEqual(rf.transient_replicas, None) + self.assertEqual(str(rf), '3') + + rf = ReplicationFactor.create('3/1') + self.assertEqual(rf.all_replicas, 3) + self.assertEqual(rf.full_replicas, 2) + self.assertEqual(rf.transient_replicas, 1) + self.assertEqual(str(rf), '3/1') + + self.assertRaises(ValueError, ReplicationFactor.create, '3/') + self.assertRaises(ValueError, ReplicationFactor.create, 'a/1') + self.assertRaises(ValueError, ReplicationFactor.create, 'a') + self.assertRaises(ValueError, ReplicationFactor.create, '3/a') + + def test_replication_factor_equality(self): + self.assertEqual(ReplicationFactor.create('3/1'), ReplicationFactor.create('3/1')) + self.assertEqual(ReplicationFactor.create('3'), ReplicationFactor.create('3')) + self.assertNotEqual(ReplicationFactor.create('3'), ReplicationFactor.create('3/1')) + self.assertNotEqual(ReplicationFactor.create('3'), ReplicationFactor.create('3/1')) + + + class StrategiesTest(unittest.TestCase): @classmethod @@ -76,6 +111,93 @@ def test_replication_strategy(self): self.assertRaises(NotImplementedError, rs.make_token_replica_map, None, None) self.assertRaises(NotImplementedError, rs.export_for_schema) + def test_simple_replication_type_parsing(self): + """ Test equality between passing numeric and string replication factor for simple strategy """ + rs = ReplicationStrategy() + + simple_int = rs.create('SimpleStrategy', 
{'replication_factor': 3}) + simple_str = rs.create('SimpleStrategy', {'replication_factor': '3'}) + + self.assertEqual(simple_int.export_for_schema(), simple_str.export_for_schema()) + self.assertEqual(simple_int, simple_str) + + # make token replica map + ring = [MD5Token(0), MD5Token(1), MD5Token(2)] + hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy) for host in range(3)] + token_to_host = dict(zip(ring, hosts)) + self.assertEqual( + simple_int.make_token_replica_map(token_to_host, ring), + simple_str.make_token_replica_map(token_to_host, ring) + ) + + def test_transient_replication_parsing(self): + """ Test that we can PARSE a transient replication factor for SimpleStrategy """ + rs = ReplicationStrategy() + + simple_transient = rs.create('SimpleStrategy', {'replication_factor': '3/1'}) + self.assertEqual(simple_transient.replication_factor_info, ReplicationFactor(3, 1)) + self.assertEqual(simple_transient.replication_factor, 2) + self.assertIn("'replication_factor': '3/1'", simple_transient.export_for_schema()) + + simple_str = rs.create('SimpleStrategy', {'replication_factor': '2'}) + self.assertNotEqual(simple_transient, simple_str) + + # make token replica map + ring = [MD5Token(0), MD5Token(1), MD5Token(2)] + hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy) for host in range(3)] + token_to_host = dict(zip(ring, hosts)) + self.assertEqual( + simple_transient.make_token_replica_map(token_to_host, ring), + simple_str.make_token_replica_map(token_to_host, ring) + ) + + def test_nts_replication_parsing(self): + """ Test equality between passing numeric and string replication factor for NTS """ + rs = ReplicationStrategy() + + nts_int = rs.create('NetworkTopologyStrategy', {'dc1': 3, 'dc2': 5}) + nts_str = rs.create('NetworkTopologyStrategy', {'dc1': '3', 'dc2': '5'}) + + self.assertEqual(nts_int.dc_replication_factors['dc1'], 3) + self.assertEqual(nts_str.dc_replication_factors['dc1'], 3) + self.assertEqual(nts_int.dc_replication_factors_info['dc1'], ReplicationFactor(3)) + self.assertEqual(nts_str.dc_replication_factors_info['dc1'], ReplicationFactor(3)) + + self.assertEqual(nts_int.export_for_schema(), nts_str.export_for_schema()) + self.assertEqual(nts_int, nts_str) + + # make token replica map + ring = [MD5Token(0), MD5Token(1), MD5Token(2)] + hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy) for host in range(3)] + token_to_host = dict(zip(ring, hosts)) + self.assertEqual( + nts_int.make_token_replica_map(token_to_host, ring), + nts_str.make_token_replica_map(token_to_host, ring) + ) + + def test_nts_transient_parsing(self): + """ Test that we can PARSE a transient replication factor for NTS """ + rs = ReplicationStrategy() + + nts_transient = rs.create('NetworkTopologyStrategy', {'dc1': '3/1', 'dc2': '5/1'}) + self.assertEqual(nts_transient.dc_replication_factors_info['dc1'], ReplicationFactor(3, 1)) + self.assertEqual(nts_transient.dc_replication_factors_info['dc2'], ReplicationFactor(5, 1)) + self.assertEqual(nts_transient.dc_replication_factors['dc1'], 2) + self.assertEqual(nts_transient.dc_replication_factors['dc2'], 4) + self.assertIn("'dc1': '3/1', 'dc2': '5/1'", nts_transient.export_for_schema()) + + nts_str = rs.create('NetworkTopologyStrategy', {'dc1': '3', 'dc2': '5'}) + self.assertNotEqual(nts_transient, nts_str) + + # make token replica map + ring = [MD5Token(0), MD5Token(1), MD5Token(2)] + hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy) for host in range(3)] + token_to_host = dict(zip(ring, hosts)) + self.assertEqual( + 
nts_transient.make_token_replica_map(token_to_host, ring), + nts_str.make_token_replica_map(token_to_host, ring) + ) + def test_nts_make_token_replica_map(self): token_to_host_owner = {} @@ -111,6 +233,45 @@ def test_nts_make_token_replica_map(self): self.assertItemsEqual(replica_map[MD5Token(0)], (dc1_1, dc1_2, dc2_1, dc2_2, dc3_1)) + def test_nts_token_performance(self): + """ + Tests to ensure that when rf exceeds the number of nodes available, that we dont' + needlessly iterate trying to construct tokens for nodes that don't exist. + + @since 3.7 + @jira_ticket PYTHON-379 + @expected_result timing with 1500 rf should be same/similar to 3rf if we have 3 nodes + + @test_category metadata + """ + + token_to_host_owner = {} + ring = [] + dc1hostnum = 3 + current_token = 0 + vnodes_per_host = 500 + for i in range(dc1hostnum): + + host = Host('dc1.{0}'.format(i), SimpleConvictionPolicy) + host.set_location_info('dc1', "rack1") + for vnode_num in range(vnodes_per_host): + md5_token = MD5Token(current_token+vnode_num) + token_to_host_owner[md5_token] = host + ring.append(md5_token) + current_token += 1000 + + nts = NetworkTopologyStrategy({'dc1': 3}) + start_time = timeit.default_timer() + nts.make_token_replica_map(token_to_host_owner, ring) + elapsed_base = timeit.default_timer() - start_time + + nts = NetworkTopologyStrategy({'dc1': 1500}) + start_time = timeit.default_timer() + nts.make_token_replica_map(token_to_host_owner, ring) + elapsed_bad = timeit.default_timer() - start_time + difference = elapsed_bad - elapsed_base + self.assertTrue(difference < 1 and difference > -1) + def test_nts_make_token_replica_map_multi_rack(self): token_to_host_owner = {} @@ -258,6 +419,41 @@ def test_is_valid_name(self): self.assertEqual(is_valid_name(keyword), False) +class GetReplicasTest(unittest.TestCase): + def _get_replicas(self, token_klass): + tokens = [token_klass(i) for i in range(0, (2 ** 127 - 1), 2 ** 125)] + hosts = [Host("ip%d" % i, SimpleConvictionPolicy) for i in range(len(tokens))] + token_to_primary_replica = dict(zip(tokens, hosts)) + keyspace = KeyspaceMetadata("ks", True, "SimpleStrategy", {"replication_factor": "1"}) + metadata = Mock(spec=Metadata, keyspaces={'ks': keyspace}) + token_map = TokenMap(token_klass, token_to_primary_replica, tokens, metadata) + + # tokens match node tokens exactly + for token, expected_host in zip(tokens, hosts): + replicas = token_map.get_replicas("ks", token) + self.assertEqual(set(replicas), {expected_host}) + + # shift the tokens back by one + for token, expected_host in zip(tokens, hosts): + replicas = token_map.get_replicas("ks", token_klass(token.value - 1)) + self.assertEqual(set(replicas), {expected_host}) + + # shift the tokens forward by one + for i, token in enumerate(tokens): + replicas = token_map.get_replicas("ks", token_klass(token.value + 1)) + expected_host = hosts[(i + 1) % len(hosts)] + self.assertEqual(set(replicas), {expected_host}) + + def test_murmur3_tokens(self): + self._get_replicas(Murmur3Token) + + def test_md5_tokens(self): + self._get_replicas(MD5Token) + + def test_bytes_tokens(self): + self._get_replicas(BytesToken) + + class Murmur3TokensTest(unittest.TestCase): def test_murmur3_init(self): @@ -290,11 +486,11 @@ def test_murmur3_c(self): raise unittest.SkipTest('The cmurmur3 extension is not available') def _verify_hash(self, fn): - self.assertEqual(fn(six.b('123')), -7468325962851647638) + self.assertEqual(fn(b'123'), -7468325962851647638) self.assertEqual(fn(b'\x00\xff\x10\xfa\x99' * 10), 5837342703291459765) 
self.assertEqual(fn(b'\xfe' * 8), -8927430733708461935) self.assertEqual(fn(b'\x10' * 8), 1446172840243228796) - self.assertEqual(fn(six.b(str(cassandra.metadata.MAX_LONG))), 7162290910810015547) + self.assertEqual(fn(str(cassandra.metadata.MAX_LONG).encode()), 7162290910810015547) class MD5TokensTest(unittest.TestCase): @@ -309,17 +505,32 @@ def test_md5_tokens(self): class BytesTokensTest(unittest.TestCase): def test_bytes_tokens(self): - bytes_token = BytesToken(str(cassandra.metadata.MIN_LONG - 1)) + bytes_token = BytesToken(unhexlify(b'01')) + self.assertEqual(bytes_token.value, b'\x01') + self.assertEqual(str(bytes_token), "" % bytes_token.value) self.assertEqual(bytes_token.hash_fn('123'), '123') self.assertEqual(bytes_token.hash_fn(123), 123) self.assertEqual(bytes_token.hash_fn(str(cassandra.metadata.MAX_LONG)), str(cassandra.metadata.MAX_LONG)) - self.assertEqual(str(bytes_token), "") - try: - bytes_token = BytesToken(cassandra.metadata.MIN_LONG - 1) - self.fail('Tokens for ByteOrderedPartitioner should be only strings') - except TypeError: - pass + def test_from_string(self): + from_unicode = BytesToken.from_string('0123456789abcdef') + from_bin = BytesToken.from_string(b'0123456789abcdef') + self.assertEqual(from_unicode, from_bin) + self.assertIsInstance(from_unicode.value, bytes) + self.assertIsInstance(from_bin.value, bytes) + + def test_comparison(self): + tok = BytesToken.from_string('0123456789abcdef') + token_high_order = uint16_unpack(tok.value[0:2]) + self.assertLess(BytesToken(uint16_pack(token_high_order - 1)), tok) + self.assertGreater(BytesToken(uint16_pack(token_high_order + 1)), tok) + + def test_comparison_unicode(self): + value = b'\'_-()"\xc2\xac' + t0 = BytesToken(value) + t1 = BytesToken.from_string('00') + self.assertGreater(t0, t1) + self.assertFalse(t0 < t1) class KeyspaceMetadataTest(unittest.TestCase): @@ -372,6 +583,35 @@ def test_as_cql_query_name_escaping(self): self.assertEqual('CREATE TYPE "MyKeyspace"."MyType" ("AbA" ascii, "keyspace" ascii)', udt.as_cql_query(formatted=False)) +class UserDefinedFunctionTest(unittest.TestCase): + def test_as_cql_query_removes_frozen(self): + func = Function( + "ks1", "myfunction", ["frozen>"], ["a"], + "int", "java", "return 0;", True, False, False, False + ) + expected_result = ( + "CREATE FUNCTION ks1.myfunction(a tuple) " + "CALLED ON NULL INPUT " + "RETURNS int " + "LANGUAGE java " + "AS $$return 0;$$" + ) + self.assertEqual(expected_result, func.as_cql_query(formatted=False)) + + +class UserDefinedAggregateTest(unittest.TestCase): + def test_as_cql_query_removes_frozen(self): + aggregate = Aggregate("ks1", "myaggregate", ["frozen>"], "statefunc", "frozen>", "finalfunc", "(0)", "tuple", False) + expected_result = ( + "CREATE AGGREGATE ks1.myaggregate(tuple) " + "SFUNC statefunc " + "STYPE tuple " + "FINALFUNC finalfunc " + "INITCOND (0)" + ) + self.assertEqual(expected_result, aggregate.as_cql_query(formatted=False)) + + class IndexTest(unittest.TestCase): def test_build_index_as_cql(self): @@ -380,9 +620,7 @@ def test_build_index_as_cql(self): column_meta.table.name = 'table_name_here' column_meta.table.keyspace_name = 'keyspace_name_here' column_meta.table.columns = {column_meta.name: column_meta} - connection = Mock() - connection.server_version = '2.1.0' - parser = get_schema_parser(connection, 0.1) + parser = get_schema_parser(Mock(), '2.1.0', None, 0.1) row = {'index_name': 'index_name_here', 'index_type': 'index_type_here'} index_meta = parser._build_index_metadata(column_meta, row) @@ -394,3 +632,230 
@@ def test_build_index_as_cql(self): index_meta = parser._build_index_metadata(column_meta, row) self.assertEqual(index_meta.as_cql_query(), "CREATE CUSTOM INDEX index_name_here ON keyspace_name_here.table_name_here (column_name_here) USING 'class_name_here'") + + +class UnicodeIdentifiersTests(unittest.TestCase): + """ + Exercise cql generation with unicode characters. Keyspace, Table, and Index names + cannot have special chars because C* names files by those identifiers, but they are + tested anyway. + + Looking for encoding errors like PYTHON-447 + """ + + name = b'\'_-()"\xc2\xac'.decode('utf-8') + + def test_keyspace_name(self): + km = KeyspaceMetadata(self.name, False, 'SimpleStrategy', {'replication_factor': 1}) + km.export_as_string() + + def test_table_name(self): + tm = TableMetadata(self.name, self.name) + tm.export_as_string() + + def test_column_name_single_partition(self): + tm = TableMetadata('ks', 'table') + cm = ColumnMetadata(tm, self.name, u'int') + tm.columns[cm.name] = cm + tm.partition_key.append(cm) + tm.export_as_string() + + def test_column_name_single_partition_single_clustering(self): + tm = TableMetadata('ks', 'table') + cm = ColumnMetadata(tm, self.name, u'int') + tm.columns[cm.name] = cm + tm.partition_key.append(cm) + cm = ColumnMetadata(tm, self.name + 'x', u'int') + tm.columns[cm.name] = cm + tm.clustering_key.append(cm) + tm.export_as_string() + + def test_column_name_multiple_partition(self): + tm = TableMetadata('ks', 'table') + cm = ColumnMetadata(tm, self.name, u'int') + tm.columns[cm.name] = cm + tm.partition_key.append(cm) + cm = ColumnMetadata(tm, self.name + 'x', u'int') + tm.columns[cm.name] = cm + tm.partition_key.append(cm) + tm.export_as_string() + + def test_index(self): + im = IndexMetadata(self.name, self.name, self.name, kind='', index_options={'target': self.name}) + log.debug(im.export_as_string()) + im = IndexMetadata(self.name, self.name, self.name, kind='CUSTOM', index_options={'target': self.name, 'class_name': 'Class'}) + log.debug(im.export_as_string()) + # PYTHON-1008 + im = IndexMetadata(self.name, self.name, self.name, kind='CUSTOM', index_options={'target': self.name, 'class_name': 'Class', 'delimiter': self.name}) + log.debug(im.export_as_string()) + + def test_function(self): + fm = Function(keyspace=self.name, name=self.name, + argument_types=(u'int', u'int'), + argument_names=(u'x', u'y'), + return_type=u'int', language=u'language', + body=self.name, called_on_null_input=False, + deterministic=True, + monotonic=False, monotonic_on=(u'x',)) + fm.export_as_string() + + def test_aggregate(self): + am = Aggregate(self.name, self.name, (u'text',), self.name, u'text', self.name, self.name, u'text', True) + am.export_as_string() + + def test_user_type(self): + um = UserType(self.name, self.name, [self.name, self.name], [u'int', u'text']) + um.export_as_string() + + +class FunctionToCQLTests(unittest.TestCase): + + base_vars = { + 'keyspace': 'ks_name', + 'name': 'function_name', + 'argument_types': (u'int', u'int'), + 'argument_names': (u'x', u'y'), + 'return_type': u'int', + 'language': u'language', + 'body': 'body', + 'called_on_null_input': False, + 'deterministic': True, + 'monotonic': False, + 'monotonic_on': () + } + + def _function_with_kwargs(self, **kwargs): + return Function(**dict(self.base_vars, + **kwargs) + ) + + def test_non_monotonic(self): + self.assertNotIn( + 'MONOTONIC', + self._function_with_kwargs( + monotonic=False, + monotonic_on=() + ).export_as_string() + ) + + def test_monotonic_all(self): + 
mono_function = self._function_with_kwargs( + monotonic=True, + monotonic_on=() + ) + self.assertIn( + 'MONOTONIC LANG', + mono_function.as_cql_query(formatted=False) + ) + self.assertIn( + 'MONOTONIC\n LANG', + mono_function.as_cql_query(formatted=True) + ) + + def test_monotonic_one(self): + mono_on_function = self._function_with_kwargs( + monotonic=False, + monotonic_on=('x',) + ) + self.assertIn( + 'MONOTONIC ON x LANG', + mono_on_function.as_cql_query(formatted=False) + ) + self.assertIn( + 'MONOTONIC ON x\n LANG', + mono_on_function.as_cql_query(formatted=True) + ) + + def test_nondeterministic(self): + self.assertNotIn( + 'DETERMINISTIC', + self._function_with_kwargs( + deterministic=False + ).as_cql_query(formatted=False) + ) + + def test_deterministic(self): + self.assertIn( + 'DETERMINISTIC', + self._function_with_kwargs( + deterministic=True + ).as_cql_query(formatted=False) + ) + self.assertIn( + 'DETERMINISTIC\n', + self._function_with_kwargs( + deterministic=True + ).as_cql_query(formatted=True) + ) + + +class AggregateToCQLTests(unittest.TestCase): + base_vars = { + 'keyspace': 'ks_name', + 'name': 'function_name', + 'argument_types': (u'int', u'int'), + 'state_func': 'funcname', + 'state_type': u'int', + 'return_type': u'int', + 'final_func': None, + 'initial_condition': '0', + 'deterministic': True + } + + def _aggregate_with_kwargs(self, **kwargs): + return Aggregate(**dict(self.base_vars, + **kwargs) + ) + + def test_nondeterministic(self): + self.assertNotIn( + 'DETERMINISTIC', + self._aggregate_with_kwargs( + deterministic=False + ).as_cql_query(formatted=True) + ) + + def test_deterministic(self): + for formatted in (True, False): + query = self._aggregate_with_kwargs( + deterministic=True + ).as_cql_query(formatted=formatted) + self.assertTrue(query.endswith('DETERMINISTIC'), + msg="'DETERMINISTIC' not found in {}".format(query) + ) + + +class HostsTests(unittest.TestCase): + def test_iterate_all_hosts_and_modify(self): + """ + PYTHON-572 + """ + metadata = Metadata() + metadata.add_or_return_host(Host('dc1.1', SimpleConvictionPolicy)) + metadata.add_or_return_host(Host('dc1.2', SimpleConvictionPolicy)) + + self.assertEqual(len(metadata.all_hosts()), 2) + + for host in metadata.all_hosts(): # this would previously raise in Py3 + metadata.remove_host(host) + + self.assertEqual(len(metadata.all_hosts()), 0) + + +class MetadataHelpersTest(unittest.TestCase): + """ For any helper functions that need unit tests """ + def test_strip_frozen(self): + self.longMessage = True + + argument_to_expected_results = [ + ('int', 'int'), + ('tuple<int>', 'tuple<int>'), + (r'map<"!@#$%^&*()[]\ frozen >>>", int>', r'map<"!@#$%^&*()[]\ frozen >>>", int>'), # A valid UDT name + ('frozen<tuple<int>>', 'tuple<int>'), + (r'frozen<map<"!@#$%^&*()[]\ frozen >>>", int>>', r'map<"!@#$%^&*()[]\ frozen >>>", int>'), + ('frozen<map<frozen<tuple<int, frozen<list<text>>, int>>, frozen<map<int, frozen<set<blob>>>>>>', + 'map<tuple<int, list<text>, int>, map<int, set<blob>>>'), + ] + for argument, expected_result in argument_to_expected_results: + result = strip_frozen(argument) + self.assertEqual(result, expected_result, "strip_frozen() arg: {}".format(argument)) diff --git a/tests/unit/test_orderedmap.py b/tests/unit/test_orderedmap.py index 6fa43ccdfe..a26994dd7b 100644 --- a/tests/unit/test_orderedmap.py +++ b/tests/unit/test_orderedmap.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership.
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,14 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest from cassandra.util import OrderedMap, OrderedMapSerializedKey from cassandra.cqltypes import EMPTY, UTF8Type, lookup_casstype -import six class OrderedMapTest(unittest.TestCase): def test_init(self): @@ -121,11 +119,11 @@ def test_iter(self): itr = iter(om) self.assertEqual(sum([1 for _ in itr]), len(keys)) - self.assertRaises(StopIteration, six.next, itr) + self.assertRaises(StopIteration, next, itr) self.assertEqual(list(iter(om)), keys) - self.assertEqual(list(six.iteritems(om)), items) - self.assertEqual(list(six.itervalues(om)), values) + self.assertEqual(list(om.items()), items) + self.assertEqual(list(om.values()), values) def test_len(self): self.assertEqual(len(OrderedMap()), 0) diff --git a/tests/unit/test_parameter_binding.py b/tests/unit/test_parameter_binding.py index d48b5d9573..fd44728c25 100644 --- a/tests/unit/test_parameter_binding.py +++ b/tests/unit/test_parameter_binding.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,10 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest from cassandra.encoder import Encoder from cassandra.protocol import ColumnMetadata @@ -24,9 +23,6 @@ from cassandra.cqltypes import Int32Type from cassandra.util import OrderedDict -from six.moves import xrange -import six - class ParamBindingTest(unittest.TestCase): @@ -43,7 +39,7 @@ def test_sequence_param(self): self.assertEqual(result, "(1, 'a', 2.0)") def test_generator_param(self): - result = bind_params("%s", ((i for i in xrange(3)),), Encoder()) + result = bind_params("%s", ((i for i in range(3)),), Encoder()) self.assertEqual(result, "[0, 1, 2]") def test_none_param(self): @@ -77,21 +73,21 @@ def test_float_precision(self): class BoundStatementTestV1(unittest.TestCase): - protocol_version=1 + protocol_version = 1 @classmethod def setUpClass(cls): - cls.prepared = PreparedStatement(column_metadata=[ - ColumnMetadata('keyspace', 'cf', 'rk0', Int32Type), - ColumnMetadata('keyspace', 'cf', 'rk1', Int32Type), - ColumnMetadata('keyspace', 'cf', 'ck0', Int32Type), - ColumnMetadata('keyspace', 'cf', 'v0', Int32Type) - ], + column_metadata = [ColumnMetadata('keyspace', 'cf', 'rk0', Int32Type), + ColumnMetadata('keyspace', 'cf', 'rk1', Int32Type), + ColumnMetadata('keyspace', 'cf', 'ck0', Int32Type), + ColumnMetadata('keyspace', 'cf', 'v0', Int32Type)] + cls.prepared = PreparedStatement(column_metadata=column_metadata, query_id=None, routing_key_indexes=[1, 0], query=None, keyspace='keyspace', - protocol_version=cls.protocol_version) + protocol_version=cls.protocol_version, result_metadata=None, + result_metadata_id=None) cls.bound = BoundStatement(prepared_statement=cls.prepared) def test_invalid_argument_type(self): @@ -130,7 +126,9 @@ def test_inherit_fetch_size(self): routing_key_indexes=[], query=None, keyspace=keyspace, - protocol_version=self.protocol_version) + protocol_version=self.protocol_version, + result_metadata=None, + result_metadata_id=None) prepared_statement.fetch_size = 1234 bound_statement = BoundStatement(prepared_statement=prepared_statement) self.assertEqual(1234, bound_statement.fetch_size) @@ -150,7 +148,7 @@ def test_missing_value(self): def test_extra_value(self): self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': 0, 'should_not_be_here': 123}) # okay to have extra keys in dict - self.assertEqual(self.bound.values, [six.b('\x00') * 4] * 4) # four encoded zeros + self.assertEqual(self.bound.values, [b'\x00' * 4] * 4) # four encoded zeros self.assertRaises(ValueError, self.bound.bind, (0, 0, 0, 0, 123)) def test_values_none(self): @@ -163,7 +161,9 @@ def test_values_none(self): routing_key_indexes=[], query=None, keyspace='whatever', - protocol_version=self.protocol_version) + protocol_version=self.protocol_version, + result_metadata=None, + result_metadata_id=None) bound = prepared_statement.bind(None) self.assertListEqual(bound.values, []) @@ -182,15 +182,15 @@ def test_unset_value(self): class BoundStatementTestV2(BoundStatementTestV1): - protocol_version=2 + protocol_version = 2 class BoundStatementTestV3(BoundStatementTestV1): - protocol_version=3 + protocol_version = 3 class BoundStatementTestV4(BoundStatementTestV1): - protocol_version=4 + protocol_version = 4 def test_dict_missing_routing_key(self): # in v4 it implicitly binds UNSET_VALUE for missing items, @@ -212,6 +212,9 @@ def test_unset_value(self): self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': UNSET_VALUE}) self.assertEqual(self.bound.values[-1], UNSET_VALUE) - old_values = self.bound.values 
self.bound.bind((0, 0, 0, UNSET_VALUE)) self.assertEqual(self.bound.values[-1], UNSET_VALUE) + + +class BoundStatementTestV5(BoundStatementTestV4): + protocol_version = 5 diff --git a/tests/unit/test_policies.py b/tests/unit/test_policies.py index a9406cf790..792268cd7f 100644 --- a/tests/unit/test_policies.py +++ b/tests/unit/test_policies.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,33 +14,30 @@ # See the License for the specific language governing permissions and # limitations under the License. -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest from itertools import islice, cycle -from mock import Mock +from unittest.mock import Mock, patch, call from random import randint -import six +from _thread import LockType import sys import struct from threading import Thread from cassandra import ConsistencyLevel from cassandra.cluster import Cluster +from cassandra.connection import DefaultEndPoint from cassandra.metadata import Metadata -from cassandra.policies import (RoundRobinPolicy, DCAwareRoundRobinPolicy, +from cassandra.policies import (RoundRobinPolicy, WhiteListRoundRobinPolicy, DCAwareRoundRobinPolicy, TokenAwarePolicy, SimpleConvictionPolicy, HostDistance, ExponentialReconnectionPolicy, RetryPolicy, WriteType, DowngradingConsistencyRetryPolicy, ConstantReconnectionPolicy, - LoadBalancingPolicy, ConvictionPolicy, ReconnectionPolicy, FallthroughRetryPolicy) + LoadBalancingPolicy, ConvictionPolicy, ReconnectionPolicy, FallthroughRetryPolicy, + IdentityTranslator, EC2MultiRegionTranslator, HostFilterPolicy) from cassandra.pool import Host from cassandra.query import Statement -from six.moves import xrange - class LoadBalancingPolicyTest(unittest.TestCase): def test_non_implemented(self): @@ -47,7 +46,7 @@ def test_non_implemented(self): """ policy = LoadBalancingPolicy() - host = Host("ip1", SimpleConvictionPolicy) + host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy) host.set_location_info("dc1", "rack1") self.assertRaises(NotImplementedError, policy.distance, host) @@ -75,7 +74,7 @@ def test_multiple_query_plans(self): hosts = [0, 1, 2, 3] policy = RoundRobinPolicy() policy.populate(None, hosts) - for i in xrange(20): + for i in range(20): qplan = list(policy.make_query_plan()) self.assertEqual(sorted(qplan), hosts) @@ -104,11 +103,13 @@ def test_thread_safety(self): def check_query_plan(): for i in range(100): qplan = list(policy.make_query_plan()) - self.assertEqual(sorted(qplan), hosts) + self.assertEqual(sorted(qplan), list(hosts)) threads = [Thread(target=check_query_plan) for i in range(4)] - map(lambda t: t.start(), threads) - map(lambda t: t.join(), threads) + for t in threads: + 
t.start() + for t in threads: + t.join() def test_thread_safety_during_modification(self): hosts = range(100) @@ -119,17 +120,17 @@ def test_thread_safety_during_modification(self): def check_query_plan(): try: - for i in xrange(100): + for i in range(100): list(policy.make_query_plan()) except Exception as exc: errors.append(exc) def host_up(): - for i in xrange(1000): + for i in range(1000): policy.on_up(randint(0, 99)) def host_down(): - for i in xrange(1000): + for i in range(1000): policy.on_down(randint(0, 99)) threads = [] @@ -140,7 +141,7 @@ def host_down(): # make the GIL switch after every instruction, maximizing # the chance of race conditions - check = six.PY2 or '__pypy__' in sys.builtin_module_names + check = '__pypy__' in sys.builtin_module_names if check: original_interval = sys.getcheckinterval() else: @@ -151,8 +152,10 @@ def host_down(): sys.setcheckinterval(0) else: sys.setswitchinterval(0.0001) - map(lambda t: t.start(), threads) - map(lambda t: t.join(), threads) + for t in threads: + t.start() + for t in threads: + t.join() finally: if check: sys.setcheckinterval(original_interval) @@ -182,7 +185,7 @@ class DCAwareRoundRobinPolicyTest(unittest.TestCase): def test_no_remote(self): hosts = [] for i in range(4): - h = Host(i, SimpleConvictionPolicy) + h = Host(DefaultEndPoint(i), SimpleConvictionPolicy) h.set_location_info("dc1", "rack1") hosts.append(h) @@ -192,7 +195,7 @@ def test_no_remote(self): self.assertEqual(sorted(qplan), sorted(hosts)) def test_with_remotes(self): - hosts = [Host(i, SimpleConvictionPolicy) for i in range(4)] + hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy) for i in range(4)] for h in hosts[:2]: h.set_location_info("dc1", "rack1") for h in hosts[2:]: @@ -201,7 +204,7 @@ def test_with_remotes(self): local_hosts = set(h for h in hosts if h.datacenter == "dc1") remote_hosts = set(h for h in hosts if h.datacenter != "dc1") - # allow all of the remote hosts to be used + # allow all the remote hosts to be used policy = DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=2) policy.populate(Mock(), hosts) qplan = list(policy.make_query_plan()) @@ -227,14 +230,14 @@ def test_with_remotes(self): def test_get_distance(self): policy = DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=0) - host = Host("ip1", SimpleConvictionPolicy) + host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy) host.set_location_info("dc1", "rack1") policy.populate(Mock(), [host]) self.assertEqual(policy.distance(host), HostDistance.LOCAL) # used_hosts_per_remote_dc is set to 0, so ignore it - remote_host = Host("ip2", SimpleConvictionPolicy) + remote_host = Host(DefaultEndPoint("ip2"), SimpleConvictionPolicy) remote_host.set_location_info("dc2", "rack1") self.assertEqual(policy.distance(remote_host), HostDistance.IGNORED) @@ -248,14 +251,14 @@ def test_get_distance(self): # since used_hosts_per_remote_dc is set to 1, only the first # remote host in dc2 will be REMOTE, the rest are IGNORED - second_remote_host = Host("ip3", SimpleConvictionPolicy) + second_remote_host = Host(DefaultEndPoint("ip3"), SimpleConvictionPolicy) second_remote_host.set_location_info("dc2", "rack1") policy.populate(Mock(), [host, remote_host, second_remote_host]) distances = set([policy.distance(remote_host), policy.distance(second_remote_host)]) self.assertEqual(distances, set([HostDistance.REMOTE, HostDistance.IGNORED])) def test_status_updates(self): - hosts = [Host(i, SimpleConvictionPolicy) for i in range(4)] + hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy) for i 
in range(4)] for h in hosts[:2]: h.set_location_info("dc1", "rack1") for h in hosts[2:]: @@ -266,11 +269,11 @@ def test_status_updates(self): policy.on_down(hosts[0]) policy.on_remove(hosts[2]) - new_local_host = Host(4, SimpleConvictionPolicy) + new_local_host = Host(DefaultEndPoint(4), SimpleConvictionPolicy) new_local_host.set_location_info("dc1", "rack1") policy.on_up(new_local_host) - new_remote_host = Host(5, SimpleConvictionPolicy) + new_remote_host = Host(DefaultEndPoint(5), SimpleConvictionPolicy) new_remote_host.set_location_info("dc9000", "rack1") policy.on_add(new_remote_host) @@ -293,7 +296,7 @@ def test_status_updates(self): self.assertEqual(qplan, []) def test_modification_during_generation(self): - hosts = [Host(i, SimpleConvictionPolicy) for i in range(4)] + hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy) for i in range(4)] for h in hosts[:2]: h.set_location_info("dc1", "rack1") for h in hosts[2:]: @@ -302,12 +305,12 @@ def test_modification_during_generation(self): policy = DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=3) policy.populate(Mock(), hosts) - # The general concept here is to change thee internal state of the + # The general concept here is to change the internal state of the # policy during plan generation. In this case we use a grey-box # approach that changes specific things during known phases of the # generator. - new_host = Host(4, SimpleConvictionPolicy) + new_host = Host(DefaultEndPoint(4), SimpleConvictionPolicy) new_host.set_location_info("dc1", "rack1") # new local before iteration @@ -418,8 +421,7 @@ def test_modification_during_generation(self): policy.on_up(hosts[2]) policy.on_up(hosts[3]) - - another_host = Host(5, SimpleConvictionPolicy) + another_host = Host(DefaultEndPoint(5), SimpleConvictionPolicy) another_host.set_location_info("dc3", "rack1") new_host.set_location_info("dc3", "rack1") @@ -453,7 +455,7 @@ def test_no_live_nodes(self): hosts = [] for i in range(4): - h = Host(i, SimpleConvictionPolicy) + h = Host(DefaultEndPoint(i), SimpleConvictionPolicy) h.set_location_info("dc1", "rack1") hosts.append(h) @@ -478,12 +480,12 @@ def test_no_nodes(self): self.assertEqual(qplan, []) def test_default_dc(self): - host_local = Host(1, SimpleConvictionPolicy, 'local') - host_remote = Host(2, SimpleConvictionPolicy, 'remote') - host_none = Host(1, SimpleConvictionPolicy) + host_local = Host(DefaultEndPoint(1), SimpleConvictionPolicy, 'local') + host_remote = Host(DefaultEndPoint(2), SimpleConvictionPolicy, 'remote') + host_none = Host(DefaultEndPoint(1), SimpleConvictionPolicy) # contact point is '1' - cluster = Mock(contact_points=[1]) + cluster = Mock(endpoints_resolved=[DefaultEndPoint(1)]) # contact DC first policy = DCAwareRoundRobinPolicy() @@ -523,7 +525,7 @@ class TokenAwarePolicyTest(unittest.TestCase): def test_wrap_round_robin(self): cluster = Mock(spec=Cluster) cluster.metadata = Mock(spec=Metadata) - hosts = [Host(str(i), SimpleConvictionPolicy) for i in range(4)] + hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)] for host in hosts: host.set_up() @@ -554,7 +556,7 @@ def get_replicas(keyspace, packed_key): def test_wrap_dc_aware(self): cluster = Mock(spec=Cluster) cluster.metadata = Mock(spec=Metadata) - hosts = [Host(str(i), SimpleConvictionPolicy) for i in range(4)] + hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)] for host in hosts: host.set_up() for h in hosts[:2]: @@ -605,7 +607,7 @@ def test_get_distance(self): """ policy = 
TokenAwarePolicy(DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=0)) - host = Host("ip1", SimpleConvictionPolicy) + host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy) host.set_location_info("dc1", "rack1") policy.populate(self.FakeCluster(), [host]) @@ -613,7 +615,7 @@ def test_get_distance(self): self.assertEqual(policy.distance(host), HostDistance.LOCAL) # used_hosts_per_remote_dc is set to 0, so ignore it - remote_host = Host("ip2", SimpleConvictionPolicy) + remote_host = Host(DefaultEndPoint("ip2"), SimpleConvictionPolicy) remote_host.set_location_info("dc2", "rack1") self.assertEqual(policy.distance(remote_host), HostDistance.IGNORED) @@ -627,7 +629,7 @@ def test_get_distance(self): # since used_hosts_per_remote_dc is set to 1, only the first # remote host in dc2 will be REMOTE, the rest are IGNORED - second_remote_host = Host("ip3", SimpleConvictionPolicy) + second_remote_host = Host(DefaultEndPoint("ip3"), SimpleConvictionPolicy) second_remote_host.set_location_info("dc2", "rack1") policy.populate(self.FakeCluster(), [host, remote_host, second_remote_host]) distances = set([policy.distance(remote_host), policy.distance(second_remote_host)]) @@ -638,7 +640,7 @@ def test_status_updates(self): Same test as DCAwareRoundRobinPolicyTest.test_status_updates() """ - hosts = [Host(i, SimpleConvictionPolicy) for i in range(4)] + hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy) for i in range(4)] for h in hosts[:2]: h.set_location_info("dc1", "rack1") for h in hosts[2:]: @@ -649,11 +651,11 @@ def test_status_updates(self): policy.on_down(hosts[0]) policy.on_remove(hosts[2]) - new_local_host = Host(4, SimpleConvictionPolicy) + new_local_host = Host(DefaultEndPoint(4), SimpleConvictionPolicy) new_local_host.set_location_info("dc1", "rack1") policy.on_up(new_local_host) - new_remote_host = Host(5, SimpleConvictionPolicy) + new_remote_host = Host(DefaultEndPoint(5), SimpleConvictionPolicy) new_remote_host.set_location_info("dc9000", "rack1") policy.on_add(new_remote_host) @@ -676,7 +678,7 @@ def test_status_updates(self): self.assertEqual(qplan, []) def test_statement_keyspace(self): - hosts = [Host(str(i), SimpleConvictionPolicy) for i in range(4)] + hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)] for host in hosts: host.set_up() @@ -730,6 +732,73 @@ def test_statement_keyspace(self): self.assertEqual(replicas + hosts[:2], qplan) cluster.metadata.get_replicas.assert_called_with(statement_keyspace, routing_key) + def test_shuffles_if_given_keyspace_and_routing_key(self): + """ + Test to validate the hosts are shuffled when `shuffle_replicas` is truthy + @since 3.8 + @jira_ticket PYTHON-676 + @expected_result shuffle should be called, because the keyspace and the + routing key are set + + @test_category policy + """ + self._assert_shuffle(keyspace='keyspace', routing_key='routing_key') + + def test_no_shuffle_if_given_no_keyspace(self): + """ + Test to validate the hosts are not shuffled when no keyspace is provided + @since 3.8 + @jira_ticket PYTHON-676 + @expected_result shuffle should not be called, because keyspace is None + + @test_category policy + """ + self._assert_shuffle(keyspace=None, routing_key='routing_key') + + def test_no_shuffle_if_given_no_routing_key(self): + """ + Test to validate the hosts are not shuffled when no routing_key is provided + @since 3.8 + @jira_ticket PYTHON-676 + @expected_result shuffle should not be called, because routing_key is None + + @test_category policy + """ + self._assert_shuffle(keyspace='keyspace',
routing_key=None) + + @patch('cassandra.policies.shuffle') + def _assert_shuffle(self, patched_shuffle, keyspace, routing_key): + hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)] + for host in hosts: + host.set_up() + + cluster = Mock(spec=Cluster) + cluster.metadata = Mock(spec=Metadata) + replicas = hosts[2:] + cluster.metadata.get_replicas.return_value = replicas + + child_policy = Mock() + child_policy.make_query_plan.return_value = hosts + child_policy.distance.return_value = HostDistance.LOCAL + + policy = TokenAwarePolicy(child_policy, shuffle_replicas=True) + policy.populate(cluster, hosts) + + cluster.metadata.get_replicas.reset_mock() + child_policy.make_query_plan.reset_mock() + query = Statement(routing_key=routing_key) + qplan = list(policy.make_query_plan(keyspace, query)) + if keyspace is None or routing_key is None: + self.assertEqual(hosts, qplan) + self.assertEqual(cluster.metadata.get_replicas.call_count, 0) + child_policy.make_query_plan.assert_called_once_with(keyspace, query) + self.assertEqual(patched_shuffle.call_count, 0) + else: + self.assertEqual(set(replicas), set(qplan[:2])) + self.assertEqual(hosts[:2], qplan[2:]) + child_policy.make_query_plan.assert_called_once_with(keyspace, query) + self.assertEqual(patched_shuffle.call_count, 1) + class ConvictionPolicyTest(unittest.TestCase): def test_not_implemented(self): @@ -810,36 +879,93 @@ def test_schedule_infinite_attempts(self): class ExponentialReconnectionPolicyTest(unittest.TestCase): + def _assert_between(self, value, min, max): + self.assertTrue(min <= value <= max) + def test_bad_vals(self): self.assertRaises(ValueError, ExponentialReconnectionPolicy, -1, 0) self.assertRaises(ValueError, ExponentialReconnectionPolicy, 0, -1) self.assertRaises(ValueError, ExponentialReconnectionPolicy, 9000, 1) - self.assertRaises(ValueError, ExponentialReconnectionPolicy, 1, 2,-1) + self.assertRaises(ValueError, ExponentialReconnectionPolicy, 1, 2, -1) def test_schedule_no_max(self): - base_delay = 2 - max_delay = 100 + base_delay = 2.0 + max_delay = 100.0 test_iter = 10000 policy = ExponentialReconnectionPolicy(base_delay=base_delay, max_delay=max_delay, max_attempts=None) sched_slice = list(islice(policy.new_schedule(), 0, test_iter)) - self.assertEqual(sched_slice[0], base_delay) - self.assertEqual(sched_slice[-1], max_delay) + self._assert_between(sched_slice[0], base_delay*0.85, base_delay*1.15) + self._assert_between(sched_slice[-1], max_delay*0.85, max_delay*1.15) self.assertEqual(len(sched_slice), test_iter) def test_schedule_with_max(self): - base_delay = 2 - max_delay = 100 + base_delay = 2.0 + max_delay = 100.0 max_attempts = 64 policy = ExponentialReconnectionPolicy(base_delay=base_delay, max_delay=max_delay, max_attempts=max_attempts) schedule = list(policy.new_schedule()) self.assertEqual(len(schedule), max_attempts) for i, delay in enumerate(schedule): if i == 0: - self.assertEqual(delay, base_delay) + self._assert_between(delay, base_delay*0.85, base_delay*1.15) elif i < 6: - self.assertEqual(delay, schedule[i - 1] * 2) + value = base_delay * (2 ** i) + self._assert_between(delay, value*85/100, value*1.15) else: - self.assertEqual(delay, max_delay) + self._assert_between(delay, max_delay*85/100, max_delay*1.15) + + def test_schedule_exactly_one_attempt(self): + base_delay = 2.0 + max_delay = 100.0 + max_attempts = 1 + policy = ExponentialReconnectionPolicy( + base_delay=base_delay, max_delay=max_delay, max_attempts=max_attempts + ) + 
self.assertEqual(len(list(policy.new_schedule())), 1) + + def test_schedule_overflow(self): + """ + Test to verify an OverflowError is handled correctly + in the ExponentialReconnectionPolicy + @since 3.10 + @jira_ticket PYTHON-707 + @expected_result all numbers should be less than sys.float_info.max + since that's the biggest max we can possibly have as that argument must be a float. + Note that it is possible for a float to be inf. + + @test_category policy + """ + + # This should lead to overflow + # Note that this may not happen in the first iterations + # as sys.float_info.max * 2 = inf + base_delay = sys.float_info.max - 1 + max_delay = sys.float_info.max + max_attempts = 2**12 + policy = ExponentialReconnectionPolicy(base_delay=base_delay, max_delay=max_delay, max_attempts=max_attempts) + schedule = list(policy.new_schedule()) + for number in schedule: + self.assertLessEqual(number, sys.float_info.max) + + def test_schedule_with_jitter(self): + """ + Test to verify jitter is added properly and is always between -/+ 15%. + + @since 3.18 + @jira_ticket PYTHON-1065 + """ + for i in range(100): + base_delay = float(randint(2, 5)) + max_delay = (base_delay - 1) * 100.0 + ep = ExponentialReconnectionPolicy(base_delay, max_delay, max_attempts=64) + schedule = ep.new_schedule() + for i in range(64): + exp_delay = min(base_delay * (2 ** i), max_delay) + min_jitter_delay = max(base_delay, exp_delay*85/100) + max_jitter_delay = min(max_delay, exp_delay*115/100) + delay = next(schedule) + self._assert_between(delay, min_jitter_delay, max_jitter_delay) + ONE = ConsistencyLevel.ONE @@ -916,13 +1042,13 @@ def test_unavailable(self): retry, consistency = policy.on_unavailable( query=None, consistency=ONE, required_replicas=1, alive_replicas=2, retry_num=0) - self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(retry, RetryPolicy.RETRY_NEXT_HOST) self.assertEqual(consistency, None) retry, consistency = policy.on_unavailable( query=None, consistency=ONE, required_replicas=10000, alive_replicas=1, retry_num=0) - self.assertEqual(retry, RetryPolicy.RETHROW) + self.assertEqual(retry, RetryPolicy.RETRY_NEXT_HOST) self.assertEqual(consistency, None) @@ -1072,7 +1198,7 @@ def test_write_timeout(self): query=None, consistency=ONE, write_type=write_type, required_responses=1, received_responses=2, retry_num=0) self.assertEqual(retry, RetryPolicy.IGNORE) - # retrhow if we can't be sure we have a replica + # rethrow if we can't be sure we have a replica retry, consistency = policy.on_write_timeout( query=None, consistency=ONE, write_type=write_type, required_responses=1, received_responses=0, retry_num=0) @@ -1113,3 +1239,262 @@ def test_unavailable(self): query=None, consistency=ONE, required_replicas=3, alive_replicas=1, retry_num=0) self.assertEqual(retry, RetryPolicy.RETRY) self.assertEqual(consistency, ConsistencyLevel.ONE) + + +class WhiteListRoundRobinPolicyTest(unittest.TestCase): + + def test_hosts_with_hostname(self): + hosts = ['localhost'] + policy = WhiteListRoundRobinPolicy(hosts) + host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy) + policy.populate(None, [host]) + + qplan = list(policy.make_query_plan()) + self.assertEqual(sorted(qplan), [host]) + + self.assertEqual(policy.distance(host), HostDistance.LOCAL) + + +class AddressTranslatorTest(unittest.TestCase): + + def test_identity_translator(self): + IdentityTranslator() + + @patch('socket.getfqdn', return_value='localhost') + def test_ec2_multi_region_translator(self, *_): + ec2t = EC2MultiRegionTranslator() + addr =
'127.0.0.1' + translated = ec2t.translate(addr) + self.assertIsNot(translated, addr) # verifies that the resolver path is followed + self.assertEqual(translated, addr) # and that it resolves to the same address + + +class HostFilterPolicyInitTest(unittest.TestCase): + + def setUp(self): + self.child_policy, self.predicate = (Mock(name='child_policy'), + Mock(name='predicate')) + + def _check_init(self, hfp): + self.assertIs(hfp._child_policy, self.child_policy) + self.assertIsInstance(hfp._hosts_lock, LockType) + + # we can't use a simple assertIs because we wrap the function + arg0, arg1 = Mock(name='arg0'), Mock(name='arg1') + hfp.predicate(arg0) + hfp.predicate(arg1) + self.predicate.assert_has_calls([call(arg0), call(arg1)]) + + def test_init_arg_order(self): + self._check_init(HostFilterPolicy(self.child_policy, self.predicate)) + + def test_init_kwargs(self): + self._check_init(HostFilterPolicy( + predicate=self.predicate, child_policy=self.child_policy + )) + + def test_immutable_predicate(self): + if sys.version_info >= (3, 11): + expected_message_regex = "has no setter" + else: + expected_message_regex = "can't set attribute" + hfp = HostFilterPolicy(child_policy=Mock(name='child_policy'), + predicate=Mock(name='predicate')) + with self.assertRaisesRegex(AttributeError, expected_message_regex): + hfp.predicate = object() + + +class HostFilterPolicyDeferralTest(unittest.TestCase): + + def setUp(self): + self.passthrough_hfp = HostFilterPolicy( + child_policy=Mock(name='child_policy'), + predicate=Mock(name='passthrough_predicate', + return_value=True) + ) + self.filterall_hfp = HostFilterPolicy( + child_policy=Mock(name='child_policy'), + predicate=Mock(name='filterall_predicate', + return_value=False) + ) + + def _check_host_triggered_method(self, policy, name): + arg, kwarg = Mock(name='arg'), Mock(name='kwarg') + method, child_policy_method = (getattr(policy, name), + getattr(policy._child_policy, name)) + + result = method(arg, kw=kwarg) + + # method calls the child policy's method... 
+ child_policy_method.assert_called_once_with(arg, kw=kwarg) + # and returns its return value + self.assertIs(result, child_policy_method.return_value) + + def test_defer_on_up_to_child_policy(self): + self._check_host_triggered_method(self.passthrough_hfp, 'on_up') + + def test_defer_on_down_to_child_policy(self): + self._check_host_triggered_method(self.passthrough_hfp, 'on_down') + + def test_defer_on_add_to_child_policy(self): + self._check_host_triggered_method(self.passthrough_hfp, 'on_add') + + def test_defer_on_remove_to_child_policy(self): + self._check_host_triggered_method(self.passthrough_hfp, 'on_remove') + + def test_filtered_host_on_up_doesnt_call_child_policy(self): + self._check_host_triggered_method(self.filterall_hfp, 'on_up') + + def test_filtered_host_on_down_doesnt_call_child_policy(self): + self._check_host_triggered_method(self.filterall_hfp, 'on_down') + + def test_filtered_host_on_add_doesnt_call_child_policy(self): + self._check_host_triggered_method(self.filterall_hfp, 'on_add') + + def test_filtered_host_on_remove_doesnt_call_child_policy(self): + self._check_host_triggered_method(self.filterall_hfp, 'on_remove') + + def _check_check_supported_deferral(self, policy): + policy.check_supported() + policy._child_policy.check_supported.assert_called_once() + + def test_check_supported_defers_to_child(self): + self._check_check_supported_deferral(self.passthrough_hfp) + + def test_check_supported_defers_to_child_when_predicate_filtered(self): + self._check_check_supported_deferral(self.filterall_hfp) + + +class HostFilterPolicyDistanceTest(unittest.TestCase): + + def setUp(self): + self.hfp = HostFilterPolicy( + child_policy=Mock(name='child_policy', distance=Mock(name='distance')), + predicate=lambda host: host.address == 'acceptme' + ) + self.ignored_host = Host(DefaultEndPoint('ignoreme'), conviction_policy_factory=Mock()) + self.accepted_host = Host(DefaultEndPoint('acceptme'), conviction_policy_factory=Mock()) + + def test_ignored_with_filter(self): + self.assertEqual(self.hfp.distance(self.ignored_host), + HostDistance.IGNORED) + self.assertNotEqual(self.hfp.distance(self.accepted_host), + HostDistance.IGNORED) + + def test_accepted_filter_defers_to_child_policy(self): + self.hfp._child_policy.distance.side_effect = distances = Mock(), Mock() + + # getting the distance for an ignored host shouldn't affect subsequent results + self.hfp.distance(self.ignored_host) + # first call of _child_policy with count() side effect + self.assertEqual(self.hfp.distance(self.accepted_host), distances[0]) + # second call of _child_policy with count() side effect + self.assertEqual(self.hfp.distance(self.accepted_host), distances[1]) + + +class HostFilterPolicyPopulateTest(unittest.TestCase): + + def test_populate_deferred_to_child(self): + hfp = HostFilterPolicy( + child_policy=Mock(name='child_policy'), + predicate=lambda host: True + ) + mock_cluster, hosts = (Mock(name='cluster'), + ['host1', 'host2', 'host3']) + hfp.populate(mock_cluster, hosts) + hfp._child_policy.populate.assert_called_once_with( + cluster=mock_cluster, + hosts=hosts + ) + + def test_child_is_populated_with_filtered_hosts(self): + hfp = HostFilterPolicy( + child_policy=Mock(name='child_policy'), + predicate=lambda host: False + ) + mock_cluster, hosts = (Mock(name='cluster'), + ['acceptme0', 'acceptme1']) + hfp.populate(mock_cluster, hosts) + hfp._child_policy.populate.assert_called_once() + self.assertEqual( + hfp._child_policy.populate.call_args[1]['hosts'], + ['acceptme0', 'acceptme1'] + ) + + +class 
HostFilterPolicyQueryPlanTest(unittest.TestCase): + + def test_query_plan_deferred_to_child(self): + child_policy = Mock( + name='child_policy', + make_query_plan=Mock( + return_value=[object(), object(), object()] + ) + ) + hfp = HostFilterPolicy( + child_policy=child_policy, + predicate=lambda host: True + ) + working_keyspace, query = (Mock(name='working_keyspace'), + Mock(name='query')) + qp = list(hfp.make_query_plan(working_keyspace=working_keyspace, + query=query)) + hfp._child_policy.make_query_plan.assert_called_once_with( + working_keyspace=working_keyspace, + query=query + ) + self.assertEqual(qp, hfp._child_policy.make_query_plan.return_value) + + def test_wrap_token_aware(self): + cluster = Mock(spec=Cluster) + hosts = [Host(DefaultEndPoint("127.0.0.{}".format(i)), SimpleConvictionPolicy) for i in range(1, 6)] + for host in hosts: + host.set_up() + + def get_replicas(keyspace, packed_key): + return hosts[:2] + + cluster.metadata.get_replicas.side_effect = get_replicas + + child_policy = TokenAwarePolicy(RoundRobinPolicy()) + + hfp = HostFilterPolicy( + child_policy=child_policy, + predicate=lambda host: host.address != "127.0.0.1" and host.address != "127.0.0.4" + ) + hfp.populate(cluster, hosts) + + # We don't allow randomness for ordering the replicas in RoundRobin + hfp._child_policy._child_policy._position = 0 + + + mocked_query = Mock() + query_plan = hfp.make_query_plan("keyspace", mocked_query) + # First the not filtered replica, and then the rest of the allowed hosts ordered + query_plan = list(query_plan) + self.assertEqual(query_plan[0], Host(DefaultEndPoint("127.0.0.2"), SimpleConvictionPolicy)) + self.assertEqual(set(query_plan[1:]),{Host(DefaultEndPoint("127.0.0.3"), SimpleConvictionPolicy), + Host(DefaultEndPoint("127.0.0.5"), SimpleConvictionPolicy)}) + + def test_create_whitelist(self): + cluster = Mock(spec=Cluster) + hosts = [Host(DefaultEndPoint("127.0.0.{}".format(i)), SimpleConvictionPolicy) for i in range(1, 6)] + for host in hosts: + host.set_up() + + child_policy = RoundRobinPolicy() + + hfp = HostFilterPolicy( + child_policy=child_policy, + predicate=lambda host: host.address == "127.0.0.1" or host.address == "127.0.0.4" + ) + hfp.populate(cluster, hosts) + + # We don't allow randomness for ordering the replicas in RoundRobin + hfp._child_policy._position = 0 + + mocked_query = Mock() + query_plan = hfp.make_query_plan("keyspace", mocked_query) + # Only the filtered replicas should be allowed + self.assertEqual(set(query_plan), {Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy), + Host(DefaultEndPoint("127.0.0.4"), SimpleConvictionPolicy)}) diff --git a/tests/unit/test_protocol.py b/tests/unit/test_protocol.py new file mode 100644 index 0000000000..08516eba9e --- /dev/null +++ b/tests/unit/test_protocol.py @@ -0,0 +1,231 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from unittest.mock import Mock + +from cassandra import ProtocolVersion, UnsupportedOperation +from cassandra.protocol import ( + PrepareMessage, QueryMessage, ExecuteMessage, UnsupportedOperation, + _PAGING_OPTIONS_FLAG, _WITH_SERIAL_CONSISTENCY_FLAG, + _PAGE_SIZE_FLAG, _WITH_PAGING_STATE_FLAG, + BatchMessage +) +from cassandra.query import BatchType +from cassandra.marshal import uint32_unpack +from cassandra.cluster import ContinuousPagingOptions + + +class MessageTest(unittest.TestCase): + + def test_prepare_message(self): + """ + Test to check the appropriate calls are made + + @since 3.9 + @jira_ticket PYTHON-713 + @expected_result the values are correctly written + + @test_category connection + """ + message = PrepareMessage("a") + io = Mock() + + message.send_body(io, 4) + self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',)]) + + io.reset_mock() + message.send_body(io, 5) + + self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',), (b'\x00\x00\x00\x00',)]) + + def test_execute_message(self): + message = ExecuteMessage('1', [], 4) + io = Mock() + + message.send_body(io, 4) + self._check_calls(io, [(b'\x00\x01',), (b'1',), (b'\x00\x04',), (b'\x01',), (b'\x00\x00',)]) + + io.reset_mock() + message.result_metadata_id = 'foo' + message.send_body(io, 5) + + self._check_calls(io, [(b'\x00\x01',), (b'1',), + (b'\x00\x03',), (b'foo',), + (b'\x00\x04',), + (b'\x00\x00\x00\x01',), (b'\x00\x00',)]) + + def test_query_message(self): + """ + Test to check the appropriate calls are made + + @since 3.9 + @jira_ticket PYTHON-713 + @expected_result the values are correctly written + + @test_category connection + """ + message = QueryMessage("a", 3) + io = Mock() + + message.send_body(io, 4) + self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',), (b'\x00\x03',), (b'\x00',)]) + + io.reset_mock() + message.send_body(io, 5) + self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',), (b'\x00\x03',), (b'\x00\x00\x00\x00',)]) + + def _check_calls(self, io, expected): + self.assertEqual( + tuple(c[1] for c in io.write.mock_calls), + tuple(expected) + ) + + def test_continuous_paging(self): + """ + Test to check continuous paging throws an Exception if it's not supported and the correct values + are written to the buffer if the option is enabled. 
+ + @since DSE 2.0b3 GRAPH 1.0b1 + @jira_ticket PYTHON-694 + @expected_result the values are correctly written + + @test_category connection + """ + max_pages = 4 + max_pages_per_second = 3 + continuous_paging_options = ContinuousPagingOptions(max_pages=max_pages, + max_pages_per_second=max_pages_per_second) + message = QueryMessage("a", 3, continuous_paging_options=continuous_paging_options) + io = Mock() + for version in [version for version in ProtocolVersion.SUPPORTED_VERSIONS + if not ProtocolVersion.has_continuous_paging_support(version)]: + self.assertRaises(UnsupportedOperation, message.send_body, io, version) + + io.reset_mock() + message.send_body(io, ProtocolVersion.DSE_V1) + + # continuous paging adds two write calls to the buffer + self.assertEqual(len(io.write.mock_calls), 6) + # Check that the appropriate flag is set to True + self.assertEqual(uint32_unpack(io.write.mock_calls[3][1][0]) & _WITH_SERIAL_CONSISTENCY_FLAG, 0) + self.assertEqual(uint32_unpack(io.write.mock_calls[3][1][0]) & _PAGE_SIZE_FLAG, 0) + self.assertEqual(uint32_unpack(io.write.mock_calls[3][1][0]) & _WITH_PAGING_STATE_FLAG, 0) + self.assertEqual(uint32_unpack(io.write.mock_calls[3][1][0]) & _PAGING_OPTIONS_FLAG, _PAGING_OPTIONS_FLAG) + + # Test max_pages and max_pages_per_second are correctly written + self.assertEqual(uint32_unpack(io.write.mock_calls[4][1][0]), max_pages) + self.assertEqual(uint32_unpack(io.write.mock_calls[5][1][0]), max_pages_per_second) + + def test_prepare_flag(self): + """ + Test to check the prepare flag is properly set, This should only happen for V5 at the moment. + + @since 3.9 + @jira_ticket PYTHON-694, PYTHON-713 + @expected_result the values are correctly written + + @test_category connection + """ + message = PrepareMessage("a") + io = Mock() + for version in ProtocolVersion.SUPPORTED_VERSIONS: + message.send_body(io, version) + if ProtocolVersion.uses_prepare_flags(version): + self.assertEqual(len(io.write.mock_calls), 3) + else: + self.assertEqual(len(io.write.mock_calls), 2) + io.reset_mock() + + def test_prepare_flag_with_keyspace(self): + message = PrepareMessage("a", keyspace='ks') + io = Mock() + + for version in ProtocolVersion.SUPPORTED_VERSIONS: + if ProtocolVersion.uses_keyspace_flag(version): + message.send_body(io, version) + self._check_calls(io, [ + (b'\x00\x00\x00\x01',), + (b'a',), + (b'\x00\x00\x00\x01',), + (b'\x00\x02',), + (b'ks',), + ]) + else: + with self.assertRaises(UnsupportedOperation): + message.send_body(io, version) + io.reset_mock() + + def test_keyspace_flag_raises_before_v5(self): + keyspace_message = QueryMessage('a', consistency_level=3, keyspace='ks') + io = Mock(name='io') + + with self.assertRaisesRegex(UnsupportedOperation, 'Keyspaces.*set'): + keyspace_message.send_body(io, protocol_version=4) + io.assert_not_called() + + def test_keyspace_written_with_length(self): + io = Mock(name='io') + base_expected = [ + (b'\x00\x00\x00\x01',), + (b'a',), + (b'\x00\x03',), + (b'\x00\x00\x00\x80',), # options w/ keyspace flag + ] + + QueryMessage('a', consistency_level=3, keyspace='ks').send_body( + io, protocol_version=5 + ) + self._check_calls(io, base_expected + [ + (b'\x00\x02',), # length of keyspace string + (b'ks',), + ]) + + io.reset_mock() + + QueryMessage('a', consistency_level=3, keyspace='keyspace').send_body( + io, protocol_version=5 + ) + self._check_calls(io, base_expected + [ + (b'\x00\x08',), # length of keyspace string + (b'keyspace',), + ]) + + def test_batch_message_with_keyspace(self): + self.maxDiff = None + io = 
Mock(name='io') + batch = BatchMessage( + batch_type=BatchType.LOGGED, + queries=((False, 'stmt a', ('param a',)), + (False, 'stmt b', ('param b',)), + (False, 'stmt c', ('param c',)) + ), + consistency_level=3, + keyspace='ks' + ) + batch.send_body(io, protocol_version=5) + self._check_calls(io, + ((b'\x00',), (b'\x00\x03',), (b'\x00',), + (b'\x00\x00\x00\x06',), (b'stmt a',), + (b'\x00\x01',), (b'\x00\x00\x00\x07',), ('param a',), + (b'\x00',), (b'\x00\x00\x00\x06',), (b'stmt b',), + (b'\x00\x01',), (b'\x00\x00\x00\x07',), ('param b',), + (b'\x00',), (b'\x00\x00\x00\x06',), (b'stmt c',), + (b'\x00\x01',), (b'\x00\x00\x00\x07',), ('param c',), + (b'\x00\x03',), + (b'\x00\x00\x00\x80',), (b'\x00\x02',), (b'ks',)) + ) diff --git a/tests/unit/test_query.py b/tests/unit/test_query.py new file mode 100644 index 0000000000..2e87da389b --- /dev/null +++ b/tests/unit/test_query.py @@ -0,0 +1,72 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from cassandra.query import BatchStatement, SimpleStatement + + +class BatchStatementTest(unittest.TestCase): + # TODO: this suite could be expanded; for now just adding a test covering a PR + + def test_clear(self): + keyspace = 'keyspace' + routing_key = 'routing_key' + custom_payload = {'key': b'value'} + + ss = SimpleStatement('whatever', keyspace=keyspace, routing_key=routing_key, custom_payload=custom_payload) + + batch = BatchStatement() + batch.add(ss) + + self.assertTrue(batch._statements_and_parameters) + self.assertEqual(batch.keyspace, keyspace) + self.assertEqual(batch.routing_key, routing_key) + self.assertEqual(batch.custom_payload, custom_payload) + + batch.clear() + self.assertFalse(batch._statements_and_parameters) + self.assertIsNone(batch.keyspace) + self.assertIsNone(batch.routing_key) + self.assertFalse(batch.custom_payload) + + batch.add(ss) + + def test_clear_empty(self): + batch = BatchStatement() + batch.clear() + self.assertFalse(batch._statements_and_parameters) + self.assertIsNone(batch.keyspace) + self.assertIsNone(batch.routing_key) + self.assertFalse(batch.custom_payload) + + batch.add('something') + + def test_add_all(self): + batch = BatchStatement() + statements = ['%s'] * 10 + parameters = [(i,) for i in range(10)] + batch.add_all(statements, parameters) + bound_statements = [t[1] for t in batch._statements_and_parameters] + str_parameters = [str(i) for i in range(10)] + self.assertEqual(bound_statements, str_parameters) + + def test_len(self): + for n in 0, 10, 100: + batch = BatchStatement() + batch.add_all(statements=['%s'] * n, + parameters=[(i,) for i in range(n)]) + self.assertEqual(len(batch), n) diff --git a/tests/unit/test_response_future.py b/tests/unit/test_response_future.py index 8e047d91ae..f9d32780de 100644 --- a/tests/unit/test_response_future.py +++ 
b/tests/unit/test_response_future.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,22 +14,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest -from mock import Mock, MagicMock, ANY +from collections import deque +from threading import RLock +from unittest.mock import Mock, MagicMock, ANY -from cassandra import ConsistencyLevel, Unavailable, SchemaTargetType, SchemaChangeType -from cassandra.cluster import Session, ResponseFuture, NoHostAvailable +from cassandra import ConsistencyLevel, Unavailable, SchemaTargetType, SchemaChangeType, OperationTimedOut +from cassandra.cluster import Session, ResponseFuture, NoHostAvailable, ProtocolVersion from cassandra.connection import Connection, ConnectionException from cassandra.protocol import (ReadTimeoutErrorMessage, WriteTimeoutErrorMessage, UnavailableErrorMessage, ResultMessage, QueryMessage, OverloadedErrorMessage, IsBootstrappingErrorMessage, PreparedQueryNotFound, PrepareMessage, RESULT_KIND_ROWS, RESULT_KIND_SET_KEYSPACE, - RESULT_KIND_SCHEMA_CHANGE, ProtocolHandler) + RESULT_KIND_SCHEMA_CHANGE, RESULT_KIND_PREPARED, + ProtocolHandler) from cassandra.policies import RetryPolicy from cassandra.pool import NoConnectionsAvailable from cassandra.query import SimpleStatement @@ -36,12 +38,20 @@ class ResponseFutureTests(unittest.TestCase): def make_basic_session(self): - return Mock(spec=Session, row_factory=lambda *x: list(x)) + s = Mock(spec=Session) + s.row_factory = lambda col_names, rows: [(col_names, rows)] + return s + + def make_pool(self): + pool = Mock() + pool.is_shutdown = False + pool.borrow_connection.return_value = [Mock(), Mock()] + return pool def make_session(self): session = self.make_basic_session() - session._load_balancer.make_query_plan.return_value = ['ip1', 'ip2'] - session._pools.get.return_value.is_shutdown = False + session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1', 'ip2'] + session._pools.get.return_value = self.make_pool() return session def make_response_future(self, session): @@ -49,12 +59,12 @@ def make_response_future(self, session): message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) return ResponseFuture(session, message, query, 1) - def make_mock_response(self, results): - return Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, results=results, paging_state=None) + def make_mock_response(self, col_names, rows): + return Mock(spec=ResultMessage, kind=RESULT_KIND_ROWS, column_names=col_names, parsed_rows=rows, paging_state=None, col_types=None) def 
test_result_message(self): session = self.make_basic_session() - session._load_balancer.make_query_plan.return_value = ['ip1', 'ip2'] + session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1', 'ip2'] pool = session._pools.get.return_value pool.is_shutdown = False @@ -67,11 +77,12 @@ def test_result_message(self): rf.session._pools.get.assert_called_once_with('ip1') pool.borrow_connection.assert_called_once_with(timeout=ANY) - connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message) + connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) - rf._set_result(self.make_mock_response([{'col': 'val'}])) - result = rf.result() - self.assertEqual(result, [{'col': 'val'}]) + expected_result = (object(), object()) + rf._set_result(None, None, None, self.make_mock_response(expected_result[0], expected_result[1])) + result = rf.result()[0] + self.assertEqual(result, expected_result) def test_unknown_result_class(self): session = self.make_session() @@ -81,7 +92,7 @@ def test_unknown_result_class(self): rf = self.make_response_future(session) rf.send_request() - rf._set_result(object()) + rf._set_result(None, None, None, object()) self.assertRaises(ConnectionException, rf.result) def test_set_keyspace_result(self): @@ -92,7 +103,7 @@ def test_set_keyspace_result(self): result = Mock(spec=ResultMessage, kind=RESULT_KIND_SET_KEYSPACE, results="keyspace1") - rf._set_result(result) + rf._set_result(None, None, None, result) rf._set_keyspace_completed({}) self.assertFalse(rf.result()) @@ -105,45 +116,77 @@ def test_schema_change_result(self): 'keyspace': "keyspace1", "table": "table1"} result = Mock(spec=ResultMessage, kind=RESULT_KIND_SCHEMA_CHANGE, - results=event_results) - rf._set_result(result) - session.submit.assert_called_once_with(ANY, ANY, rf, **event_results) + schema_change_event=event_results) + connection = Mock() + rf._set_result(None, connection, None, result) + session.submit.assert_called_once_with(ANY, ANY, rf, connection, **event_results) def test_other_result_message_kind(self): session = self.make_session() rf = self.make_response_future(session) rf.send_request() - result = [1, 2, 3] - rf._set_result(Mock(spec=ResultMessage, kind=999, results=result)) - self.assertListEqual(list(rf.result()), result) + result = Mock(spec=ResultMessage, kind=999, results=[1, 2, 3]) + rf._set_result(None, None, None, result) + self.assertEqual(rf.result()[0], result) + + def test_heartbeat_defunct_deadlock(self): + """ + Heartbeat defuncts all connections and clears request queues. Response future times out and even + if it has been removed from request queue, timeout exception must be thrown. Otherwise event loop + will deadlock on eventual ResponseFuture.result() call. 
+ + PYTHON-1044 + """ + + connection = MagicMock(spec=Connection) + connection._requests = {} + + pool = Mock() + pool.is_shutdown = False + pool.borrow_connection.return_value = [connection, 1] + + session = self.make_basic_session() + session.cluster._default_load_balancing_policy.make_query_plan.return_value = [Mock(), Mock()] + session._pools.get.return_value = pool + + query = SimpleStatement("SELECT * FROM foo") + message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) + + rf = ResponseFuture(session, message, query, 1) + rf.send_request() + + # Simulate Connection.error_all_requests() after heartbeat defuncts + connection._requests = {} + + # Simulate ResponseFuture timing out + rf._on_timeout() + self.assertRaisesRegex(OperationTimedOut, "Connection defunct by heartbeat", rf.result) def test_read_timeout_error_message(self): session = self.make_session() query = SimpleStatement("SELECT * FROM foo") - query.retry_policy = Mock() - query.retry_policy.on_read_timeout.return_value = (RetryPolicy.RETHROW, None) message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) rf = ResponseFuture(session, message, query, 1) rf.send_request() - result = Mock(spec=ReadTimeoutErrorMessage, info={}) - rf._set_result(result) + result = Mock(spec=ReadTimeoutErrorMessage, info={"data_retrieved": "", "required_responses":2, + "received_responses":1, "consistency": 1}) + rf._set_result(None, None, None, result) self.assertRaises(Exception, rf.result) def test_write_timeout_error_message(self): session = self.make_session() query = SimpleStatement("INSERT INFO foo (a, b) VALUES (1, 2)") - query.retry_policy = Mock() - query.retry_policy.on_write_timeout.return_value = (RetryPolicy.RETHROW, None) message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) rf = ResponseFuture(session, message, query, 1) rf.send_request() - result = Mock(spec=WriteTimeoutErrorMessage, info={}) - rf._set_result(result) + result = Mock(spec=WriteTimeoutErrorMessage, info={"write_type": 1, "required_responses":2, + "received_responses":1, "consistency": 1}) + rf._set_result(None, None, None, result) self.assertRaises(Exception, rf.result) def test_unavailable_error_message(self): @@ -154,24 +197,47 @@ def test_unavailable_error_message(self): message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) rf = ResponseFuture(session, message, query, 1) + rf._query_retries = 1 rf.send_request() - result = Mock(spec=UnavailableErrorMessage, info={}) - rf._set_result(result) + result = Mock(spec=UnavailableErrorMessage, info={"required_replicas":2, "alive_replicas": 1, "consistency": 1}) + rf._set_result(None, None, None, result) self.assertRaises(Exception, rf.result) + def test_request_error_with_prepare_message(self): + session = self.make_session() + query = SimpleStatement("SELECT * FROM foobar") + retry_policy = Mock() + retry_policy.on_request_error.return_value = (RetryPolicy.RETHROW, None) + message = PrepareMessage(query=query) + + rf = ResponseFuture(session, message, query, 1, retry_policy=retry_policy) + rf._query_retries = 1 + rf.send_request() + result = Mock(spec=OverloadedErrorMessage) + result.to_exception.return_value = result + rf._set_result(None, None, None, result) + self.assertIsInstance(rf._final_exception, OverloadedErrorMessage) + + rf = ResponseFuture(session, message, query, 1, retry_policy=retry_policy) + rf._query_retries = 1 + rf.send_request() + result = Mock(spec=ConnectionException) + rf._set_result(None, None, None, result) + 
self.assertIsInstance(rf._final_exception, ConnectionException) + def test_retry_policy_says_ignore(self): session = self.make_session() query = SimpleStatement("INSERT INFO foo (a, b) VALUES (1, 2)") - query.retry_policy = Mock() - query.retry_policy.on_unavailable.return_value = (RetryPolicy.IGNORE, None) message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) - rf = ResponseFuture(session, message, query, 1) + retry_policy = Mock() + retry_policy.on_unavailable.return_value = (RetryPolicy.IGNORE, None) + rf = ResponseFuture(session, message, query, 1, retry_policy=retry_policy) rf.send_request() result = Mock(spec=UnavailableErrorMessage, info={}) - rf._set_result(result) + rf._set_result(None, None, None, result) self.assertFalse(rf.result()) def test_retry_policy_says_retry(self): @@ -179,37 +245,39 @@ def test_retry_policy_says_retry(self): pool = session._pools.get.return_value query = SimpleStatement("INSERT INFO foo (a, b) VALUES (1, 2)") - query.retry_policy = Mock() - query.retry_policy.on_unavailable.return_value = (RetryPolicy.RETRY, ConsistencyLevel.ONE) message = QueryMessage(query=query, consistency_level=ConsistencyLevel.QUORUM) connection = Mock(spec=Connection) pool.borrow_connection.return_value = (connection, 1) - rf = ResponseFuture(session, message, query, 1) + retry_policy = Mock() + retry_policy.on_unavailable.return_value = (RetryPolicy.RETRY, ConsistencyLevel.ONE) + + rf = ResponseFuture(session, message, query, 1, retry_policy=retry_policy) rf.send_request() rf.session._pools.get.assert_called_once_with('ip1') pool.borrow_connection.assert_called_once_with(timeout=ANY) - connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message) + connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) result = Mock(spec=UnavailableErrorMessage, info={}) - rf._set_result(result) + host = Mock() + rf._set_result(host, None, None, result) - session.submit.assert_called_once_with(rf._retry_task, True) + session.submit.assert_called_once_with(rf._retry_task, True, host) self.assertEqual(1, rf._query_retries) connection = Mock(spec=Connection) pool.borrow_connection.return_value = (connection, 2) # simulate the executor running this - rf._retry_task(True) + rf._retry_task(True, host) # it should try again with the same host since this was # an UnavailableException - rf.session._pools.get.assert_called_with('ip1') + rf.session._pools.get.assert_called_with(host) pool.borrow_connection.assert_called_with(timeout=ANY) - connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message) + connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) def test_retry_with_different_host(self): session = self.make_session() @@ -224,25 +292,26 @@ def test_retry_with_different_host(self): rf.session._pools.get.assert_called_once_with('ip1') pool.borrow_connection.assert_called_once_with(timeout=ANY) - connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message) + connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, 
result_metadata=[]) self.assertEqual(ConsistencyLevel.QUORUM, rf.message.consistency_level) result = Mock(spec=OverloadedErrorMessage, info={}) - rf._set_result(result) + host = Mock() + rf._set_result(host, None, None, result) - session.submit.assert_called_once_with(rf._retry_task, False) - # query_retries does not get incremented for Overloaded/Bootstrapping errors - self.assertEqual(0, rf._query_retries) + session.submit.assert_called_once_with(rf._retry_task, False, host) + # query_retries does get incremented for Overloaded/Bootstrapping errors (since 3.18) + self.assertEqual(1, rf._query_retries) connection = Mock(spec=Connection) pool.borrow_connection.return_value = (connection, 2) # simulate the executor running this - rf._retry_task(False) + rf._retry_task(False, host) # it should try with a different host rf.session._pools.get.assert_called_with('ip2') pool.borrow_connection.assert_called_with(timeout=ANY) - connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message) + connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) # the consistency level should be the same self.assertEqual(ConsistencyLevel.QUORUM, rf.message.consistency_level) @@ -258,27 +327,28 @@ def test_all_retries_fail(self): rf.session._pools.get.assert_called_once_with('ip1') result = Mock(spec=IsBootstrappingErrorMessage, info={}) - rf._set_result(result) + host = Mock() + rf._set_result(host, None, None, result) # simulate the executor running this - session.submit.assert_called_once_with(rf._retry_task, False) - rf._retry_task(False) + session.submit.assert_called_once_with(rf._retry_task, False, host) + rf._retry_task(False, host) # it should try with a different host rf.session._pools.get.assert_called_with('ip2') result = Mock(spec=IsBootstrappingErrorMessage, info={}) - rf._set_result(result) + rf._set_result(host, None, None, result) # simulate the executor running this - session.submit.assert_called_with(rf._retry_task, False) - rf._retry_task(False) + session.submit.assert_called_with(rf._retry_task, False, host) + rf._retry_task(False, host) self.assertRaises(NoHostAvailable, rf.result) def test_all_pools_shutdown(self): session = self.make_basic_session() - session._load_balancer.make_query_plan.return_value = ['ip1', 'ip2'] + session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1', 'ip2'] session._pools.get.return_value.is_shutdown = True rf = ResponseFuture(session, Mock(), Mock(), 1) @@ -287,21 +357,26 @@ def test_all_pools_shutdown(self): def test_first_pool_shutdown(self): session = self.make_basic_session() - session._load_balancer.make_query_plan.return_value = ['ip1', 'ip2'] + session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1', 'ip2'] # first return a pool with is_shutdown=True, then is_shutdown=False - session._pools.get.side_effect = [Mock(is_shutdown=True), Mock(is_shutdown=False)] + pool_shutdown = self.make_pool() + pool_shutdown.is_shutdown = True + pool_ok = self.make_pool() + pool_ok.is_shutdown = True + session._pools.get.side_effect = [pool_shutdown, pool_ok] rf = self.make_response_future(session) rf.send_request() - rf._set_result(self.make_mock_response([{'col': 'val'}])) + expected_result = (object(), object()) + rf._set_result(None, None, None, self.make_mock_response(expected_result[0], expected_result[1])) - result = rf.result() 
- self.assertEqual(result, [{'col': 'val'}]) + result = rf.result()[0] + self.assertEqual(result, expected_result) def test_timeout_getting_connection_from_pool(self): session = self.make_basic_session() - session._load_balancer.make_query_plan.return_value = ['ip1', 'ip2'] + session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1', 'ip2'] # the first pool will raise an exception on borrow_connection() exc = NoConnectionsAvailable() @@ -318,8 +393,9 @@ def test_timeout_getting_connection_from_pool(self): rf = self.make_response_future(session) rf.send_request() - rf._set_result(self.make_mock_response([{'col': 'val'}])) - self.assertEqual(rf.result(), [{'col': 'val'}]) + expected_result = (object(), object()) + rf._set_result(None, None, None, self.make_mock_response(expected_result[0], expected_result[1])) + self.assertEqual(rf.result()[0], expected_result) # make sure the exception is recorded correctly self.assertEqual(rf._errors, {'ip1': exc}) @@ -330,20 +406,20 @@ def test_callback(self): rf.send_request() callback = Mock() - expected_result = [{'col': 'val'}] + expected_result = (object(), object()) arg = "positional" kwargs = {'one': 1, 'two': 2} rf.add_callback(callback, arg, **kwargs) - rf._set_result(self.make_mock_response(expected_result)) + rf._set_result(None, None, None, self.make_mock_response(expected_result[0], expected_result[1])) - result = rf.result() + result = rf.result()[0] self.assertEqual(result, expected_result) - callback.assert_called_once_with(expected_result, arg, **kwargs) + callback.assert_called_once_with([expected_result], arg, **kwargs) # this should get called immediately now that the result is set - rf.add_callback(self.assertEqual, [{'col': 'val'}]) + rf.add_callback(self.assertEqual, [expected_result]) def test_errback(self): session = self.make_session() @@ -352,17 +428,18 @@ def test_errback(self): pool.borrow_connection.return_value = (connection, 1) query = SimpleStatement("INSERT INFO foo (a, b) VALUES (1, 2)") - query.retry_policy = Mock() - query.retry_policy.on_unavailable.return_value = (RetryPolicy.RETHROW, None) message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) rf = ResponseFuture(session, message, query, 1) + rf._query_retries = 1 rf.send_request() rf.add_errback(self.assertIsInstance, Exception) - result = Mock(spec=UnavailableErrorMessage, info={}) - rf._set_result(result) + result = Mock(spec=UnavailableErrorMessage, info={"required_replicas":2, "alive_replicas": 1, "consistency": 1}) + result.to_exception.return_value = Exception() + + rf._set_result(None, None, None, result) self.assertRaises(Exception, rf.result) # this should get called immediately now that the error is set @@ -374,7 +451,7 @@ def test_multiple_callbacks(self): rf.send_request() callback = Mock() - expected_result = [{'col': 'val'}] + expected_result = (object(), object()) arg = "positional" kwargs = {'one': 1, 'two': 2} rf.add_callback(callback, arg, **kwargs) @@ -384,13 +461,13 @@ def test_multiple_callbacks(self): kwargs2 = {'three': 3, 'four': 4} rf.add_callback(callback2, arg2, **kwargs2) - rf._set_result(self.make_mock_response(expected_result)) + rf._set_result(None, None, None, self.make_mock_response(expected_result[0], expected_result[1])) - result = rf.result() + result = rf.result()[0] self.assertEqual(result, expected_result) - callback.assert_called_once_with(expected_result, arg, **kwargs) - callback2.assert_called_once_with(expected_result, arg2, **kwargs2) + 
callback.assert_called_once_with([expected_result], arg, **kwargs) + callback2.assert_called_once_with([expected_result], arg2, **kwargs2) def test_multiple_errbacks(self): session = self.make_session() @@ -399,11 +476,11 @@ def test_multiple_errbacks(self): pool.borrow_connection.return_value = (connection, 1) query = SimpleStatement("INSERT INFO foo (a, b) VALUES (1, 2)") - query.retry_policy = Mock() - query.retry_policy.on_unavailable.return_value = (RetryPolicy.RETHROW, None) message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) - rf = ResponseFuture(session, message, query, 1) + retry_policy = Mock() + retry_policy.on_unavailable.return_value = (RetryPolicy.RETHROW, None) + rf = ResponseFuture(session, message, query, 1, retry_policy=retry_policy) rf.send_request() callback = Mock() @@ -417,9 +494,10 @@ def test_multiple_errbacks(self): rf.add_errback(callback2, arg2, **kwargs2) expected_exception = Unavailable("message", 1, 2, 3) - result = Mock(spec=UnavailableErrorMessage, info={'something': 'here'}) + result = Mock(spec=UnavailableErrorMessage, info={"required_replicas":2, "alive_replicas": 1, "consistency": 1}) result.to_exception.return_value = expected_exception - rf._set_result(result) + rf._set_result(None, None, None, result) + rf._event.set() self.assertRaises(Exception, rf.result) callback.assert_called_once_with(expected_exception, arg, **kwargs) @@ -428,20 +506,21 @@ def test_multiple_errbacks(self): def test_add_callbacks(self): session = self.make_session() query = SimpleStatement("INSERT INFO foo (a, b) VALUES (1, 2)") - query.retry_policy = Mock() - query.retry_policy.on_unavailable.return_value = (RetryPolicy.RETHROW, None) message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) # test errback rf = ResponseFuture(session, message, query, 1) + rf._query_retries = 1 rf.send_request() rf.add_callbacks( callback=self.assertEqual, callback_args=([{'col': 'val'}],), errback=self.assertIsInstance, errback_args=(Exception,)) - result = Mock(spec=UnavailableErrorMessage, info={}) - rf._set_result(result) + result = Mock(spec=UnavailableErrorMessage, + info={"required_replicas":2, "alive_replicas": 1, "consistency": 1}) + result.to_exception.return_value = Exception() + rf._set_result(None, None, None, result) self.assertRaises(Exception, rf.result) # test callback @@ -449,17 +528,17 @@ def test_add_callbacks(self): rf.send_request() callback = Mock() - expected_result = [{'col': 'val'}] + expected_result = (object(), object()) arg = "positional" kwargs = {'one': 1, 'two': 2} rf.add_callbacks( callback=callback, callback_args=(arg,), callback_kwargs=kwargs, errback=self.assertIsInstance, errback_args=(Exception,)) - rf._set_result(self.make_mock_response(expected_result)) - self.assertEqual(rf.result(), expected_result) + rf._set_result(None, None, None, self.make_mock_response(expected_result[0], expected_result[1])) + self.assertEqual(rf.result()[0], expected_result) - callback.assert_called_once_with(expected_result, arg, **kwargs) + callback.assert_called_once_with([expected_result], arg, **kwargs) def test_prepared_query_not_found(self): session = self.make_session() @@ -470,6 +549,7 @@ def test_prepared_query_not_found(self): rf = self.make_response_future(session) rf.send_request() + session.cluster.protocol_version = ProtocolVersion.V4 session.cluster._prepared_statements = MagicMock(dict) prepared_statement = session.cluster._prepared_statements.__getitem__.return_value prepared_statement.query_string = "SELECT * FROM 
foobar" @@ -477,13 +557,13 @@ def test_prepared_query_not_found(self): rf._connection.keyspace = "FooKeyspace" result = Mock(spec=PreparedQueryNotFound, info='a' * 16) - rf._set_result(result) + rf._set_result(None, None, None, result) - session.submit.assert_called_once() + self.assertTrue(session.submit.call_args) args, kwargs = session.submit.call_args - self.assertEqual(rf._reprepare, args[-2]) - self.assertIsInstance(args[-1], PrepareMessage) - self.assertEqual(args[-1].query, "SELECT * FROM foobar") + self.assertEqual(rf._reprepare, args[-5]) + self.assertIsInstance(args[-4], PrepareMessage) + self.assertEqual(args[-4].query, "SELECT * FROM foobar") def test_prepared_query_not_found_bad_keyspace(self): session = self.make_session() @@ -494,6 +574,7 @@ def test_prepared_query_not_found_bad_keyspace(self): rf = self.make_response_future(session) rf.send_request() + session.cluster.protocol_version = ProtocolVersion.V4 session.cluster._prepared_statements = MagicMock(dict) prepared_statement = session.cluster._prepared_statements.__getitem__.return_value prepared_statement.query_string = "SELECT * FROM foobar" @@ -501,5 +582,52 @@ def test_prepared_query_not_found_bad_keyspace(self): rf._connection.keyspace = "BarKeyspace" result = Mock(spec=PreparedQueryNotFound, info='a' * 16) - rf._set_result(result) + rf._set_result(None, None, None, result) self.assertRaises(ValueError, rf.result) + + def test_repeat_orig_query_after_succesful_reprepare(self): + query_id = b'abc123' # Just a random binary string so we don't hit id mismatch exception + session = self.make_session() + rf = self.make_response_future(session) + + response = Mock(spec=ResultMessage, + kind=RESULT_KIND_PREPARED, + result_metadata_id='foo') + response.results = (None, None, None, None, None) + response.query_id = query_id + + rf._query = Mock(return_value=True) + rf._execute_after_prepare('host', None, None, response) + rf._query.assert_called_once_with('host') + + rf.prepared_statement = Mock() + rf.prepared_statement.query_id = query_id + rf._query = Mock(return_value=True) + rf._execute_after_prepare('host', None, None, response) + rf._query.assert_called_once_with('host') + + def test_timeout_does_not_release_stream_id(self): + """ + Make sure that stream ID is not reused immediately after client-side + timeout. Otherwise, a new request could reuse the stream ID and would + risk getting a response for the old, timed out query. 
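To make that invariant concrete, here is a minimal, hypothetical bookkeeping sketch (StreamIds and its methods are invented names, not the driver's Connection class): a stream id that timed out client-side is parked as orphaned and only returns to the free pool once the late server response actually arrives.

from collections import deque

class StreamIds:
    def __init__(self, ids):
        self.free = deque(ids)
        self.orphaned = set()

    def borrow(self):
        return self.free.popleft()

    def timed_out(self, stream_id):
        self.orphaned.add(stream_id)        # do NOT return it to self.free yet

    def response_arrived(self, stream_id):
        self.orphaned.discard(stream_id)
        self.free.append(stream_id)         # only now is reuse safe

ids = StreamIds([1, 2, 3])
sid = ids.borrow()
ids.timed_out(sid)
assert sid not in ids.free                  # reuse would mis-route a late reply
ids.response_arrived(sid)
assert sid in ids.free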
+ """ + session = self.make_basic_session() + session.cluster._default_load_balancing_policy.make_query_plan.return_value = [Mock(endpoint='ip1'), Mock(endpoint='ip2')] + pool = self.make_pool() + session._pools.get.return_value = pool + connection = Mock(spec=Connection, lock=RLock(), _requests={}, request_ids=deque(), + orphaned_request_ids=set(), orphaned_threshold=256) + pool.borrow_connection.return_value = (connection, 1) + + rf = self.make_response_future(session) + rf.send_request() + + connection._requests[1] = (connection._handle_options_response, ProtocolHandler.decode_message, []) + + rf._on_timeout() + pool.return_connection.assert_called_once_with(connection, stream_was_orphaned=True) + self.assertRaisesRegex(OperationTimedOut, "Client request timeout", rf.result) + + assert len(connection.request_ids) == 0, \ + "Request IDs should be empty but it's not: {}".format(connection.request_ids) diff --git a/tests/unit/test_resultset.py b/tests/unit/test_resultset.py index 2deeb30f75..340169d198 100644 --- a/tests/unit/test_resultset.py +++ b/tests/unit/test_resultset.py @@ -1,27 +1,24 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
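The ResultSet paging tests that follow hinge on one loop shape: drain the current page, then keep fetching further pages while the response future reports has_more_pages, tolerating empty intermediate pages. A self-contained sketch under those assumptions (FakePager and iter_all_pages are invented stand-ins, not driver API):

def iter_all_pages(pager):
    page = pager.current_page
    while True:
        for row in page:
            yield row
        if not pager.has_more_pages:
            return
        page = pager.fetch_next_page()

class FakePager:
    def __init__(self, pages):
        self._pages = list(pages)
        self.current_page = self._pages.pop(0)

    @property
    def has_more_pages(self):
        return bool(self._pages)

    def fetch_next_page(self):
        return self._pages.pop(0)

assert list(iter_all_pages(FakePager([[], [0, 1, 2], [], [3, 4]]))) == [0, 1, 2, 3, 4]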
-from cassandra.query import named_tuple_factory, dict_factory, tuple_factory - -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa -from mock import Mock, PropertyMock -import warnings +import unittest +from unittest.mock import Mock, PropertyMock, patch from cassandra.cluster import ResultSet +from cassandra.query import named_tuple_factory, dict_factory, tuple_factory class ResultSetTests(unittest.TestCase): @@ -34,7 +31,7 @@ def test_iter_non_paged(self): def test_iter_paged(self): expected = list(range(10)) - response_future = Mock(has_more_pages=True) + response_future = Mock(has_more_pages=True, _continuous_paging_session=None) response_future.result.side_effect = (ResultSet(Mock(), expected[-5:]), ) # ResultSet is iterable, so it must be protected in order to be returned whole by the Mock rs = ResultSet(response_future, expected[:5]) itr = iter(rs) @@ -42,6 +39,19 @@ def test_iter_paged(self): type(response_future).has_more_pages = PropertyMock(side_effect=(True, True, False)) # after init to avoid side effects being consumed by init self.assertListEqual(list(itr), expected) + def test_iter_paged_with_empty_pages(self): + expected = list(range(10)) + response_future = Mock(has_more_pages=True, _continuous_paging_session=None) + response_future.result.side_effect = [ + ResultSet(Mock(), []), + ResultSet(Mock(), [0, 1, 2, 3, 4]), + ResultSet(Mock(), []), + ResultSet(Mock(), [5, 6, 7, 8, 9]), + ] + rs = ResultSet(response_future, []) + itr = iter(rs) + self.assertListEqual(list(itr), expected) + def test_list_non_paged(self): # list access on RS for backwards-compatibility expected = list(range(10)) @@ -53,7 +63,7 @@ def test_list_non_paged(self): def test_list_paged(self): # list access on RS for backwards-compatibility expected = list(range(10)) - response_future = Mock(has_more_pages=True) + response_future = Mock(has_more_pages=True, _continuous_paging_session=None) response_future.result.side_effect = (ResultSet(Mock(), expected[-5:]), ) # ResultSet is iterable, so it must be protected in order to be returned whole by the Mock rs = ResultSet(response_future, expected[:5]) # this is brittle, depends on internal impl details. Would like to find a better way @@ -86,7 +96,7 @@ def test_iterate_then_index(self): self.assertFalse(list(rs)) # RuntimeError if indexing during or after pages - response_future = Mock(has_more_pages=True) + response_future = Mock(has_more_pages=True, _continuous_paging_session=None) response_future.result.side_effect = (ResultSet(Mock(), expected[-5:]), ) # ResultSet is iterable, so it must be protected in order to be returned whole by the Mock rs = ResultSet(response_future, expected[:5]) type(response_future).has_more_pages = PropertyMock(side_effect=(True, False)) @@ -112,14 +122,14 @@ def test_index_list_mode(self): # index access before iteration causes list to be materialized self.assertEqual(rs[0], expected[0]) - # resusable iteration + # reusable iteration self.assertListEqual(list(rs), expected) self.assertListEqual(list(rs), expected) self.assertTrue(rs) # pages - response_future = Mock(has_more_pages=True) + response_future = Mock(has_more_pages=True, _continuous_paging_session=None) response_future.result.side_effect = (ResultSet(Mock(), expected[-5:]), ) # ResultSet is iterable, so it must be protected in order to be returned whole by the Mock rs = ResultSet(response_future, expected[:5]) # this is brittle, depends on internal impl details. 
Would like to find a better way @@ -127,7 +137,7 @@ def test_index_list_mode(self): # index access before iteration causes list to be materialized self.assertEqual(rs[0], expected[0]) self.assertEqual(rs[9], expected[9]) - # resusable iteration + # reusable iteration self.assertListEqual(list(rs), expected) self.assertListEqual(list(rs), expected) @@ -147,7 +157,7 @@ def test_eq(self): self.assertTrue(rs) # pages - response_future = Mock(has_more_pages=True) + response_future = Mock(has_more_pages=True, _continuous_paging_session=None) response_future.result.side_effect = (ResultSet(Mock(), expected[-5:]), ) # ResultSet is iterable, so it must be protected in order to be returned whole by the Mock rs = ResultSet(response_future, expected[:5]) type(response_future).has_more_pages = PropertyMock(side_effect=(True, True, True, False)) @@ -188,3 +198,30 @@ def test_was_applied(self): for applied in (True, False): rs = ResultSet(Mock(row_factory=row_factory), [{'[applied]': applied}]) self.assertEqual(rs.was_applied, applied) + + def test_one(self): + # no pages + first, second = Mock(), Mock() + rs = ResultSet(Mock(has_more_pages=False), [first, second]) + + self.assertEqual(rs.one(), first) + + def test_all(self): + first, second = Mock(), Mock() + rs1 = ResultSet(Mock(has_more_pages=False), [first, second]) + rs2 = ResultSet(Mock(has_more_pages=False), [first, second]) + + self.assertEqual(rs1.all(), list(rs2)) + + @patch('cassandra.cluster.warn') + def test_indexing_deprecation(self, mocked_warn): + # normally we'd use catch_warnings to test this, but that doesn't work + # pre-Py3.0 for some reason + first, second = Mock(), Mock() + rs = ResultSet(Mock(has_more_pages=False), [first, second]) + self.assertEqual(rs[0], first) + self.assertEqual(len(mocked_warn.mock_calls), 1) + index_warning_args = tuple(mocked_warn.mock_calls[0])[1] + self.assertIn('indexing support will be removed in 4.0', + str(index_warning_args[0])) + self.assertIs(index_warning_args[1], DeprecationWarning) diff --git a/tests/unit/test_row_factories.py b/tests/unit/test_row_factories.py new file mode 100644 index 0000000000..0055497a54 --- /dev/null +++ b/tests/unit/test_row_factories.py @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
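The new test_row_factories.py revolves around one constraint: on CPython 3.0 through 3.6 a namedtuple row class with a very wide column list could not be created, so a row factory needs a degraded fallback there. A hedged sketch of that idea (make_row_class and the 255-field cutoff are assumptions for illustration, not the driver's pseudo_namedtuple_factory):

import sys
from collections import namedtuple

def make_row_class(colnames):
    if sys.version_info >= (3, 7) or len(colnames) <= 255:
        return namedtuple("Row", colnames, rename=True)

    class Row:                          # attribute-only rows, loses tuple behaviour
        def __init__(self, *values):
            for name, value in zip(colnames, values):
                setattr(self, name, value)
    return Row

RowClass = make_row_class(["col{}".format(i) for i in range(300)])
row = RowClass(*["value{}".format(i) for i in range(300)])
assert row.col0 == "value0"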
+ + +from cassandra.query import named_tuple_factory + +import logging +import warnings + +import sys + +from unittest import TestCase + + +log = logging.getLogger(__name__) + + +NAMEDTUPLE_CREATION_BUG = sys.version_info >= (3,) and sys.version_info < (3, 7) + +class TestNamedTupleFactory(TestCase): + + long_colnames, long_rows = ( + ['col{}'.format(x) for x in range(300)], + [ + ['value{}'.format(x) for x in range(300)] + for _ in range(100) + ] + ) + short_colnames, short_rows = ( + ['col{}'.format(x) for x in range(200)], + [ + ['value{}'.format(x) for x in range(200)] + for _ in range(100) + ] + ) + + def test_creation_warning_on_long_column_list(self): + """ + Reproduces the failure described in PYTHON-893 + + @since 3.15 + @jira_ticket PYTHON-893 + @expected_result creation fails on Python > 3 and < 3.7 + + @test_category row_factory + """ + if not NAMEDTUPLE_CREATION_BUG: + named_tuple_factory(self.long_colnames, self.long_rows) + return + + with warnings.catch_warnings(record=True) as w: + rows = named_tuple_factory(self.long_colnames, self.long_rows) + self.assertEqual(len(w), 1) + warning = w[0] + self.assertIn('pseudo_namedtuple_factory', str(warning)) + self.assertIn('3.7', str(warning)) + + for r in rows: + self.assertEqual(r.col0, self.long_rows[0][0]) + + def test_creation_no_warning_on_short_column_list(self): + """ + Tests that normal namedtuple row creation still works after PYTHON-893 fix + + @since 3.15 + @jira_ticket PYTHON-893 + @expected_result creates namedtuple-based Rows + + @test_category row_factory + """ + with warnings.catch_warnings(record=True) as w: + rows = named_tuple_factory(self.short_colnames, self.short_rows) + self.assertEqual(len(w), 0) + # check that this is a real namedtuple + self.assertTrue(hasattr(rows[0], '_fields')) + self.assertIsInstance(rows[0], tuple) diff --git a/tests/unit/test_segment.py b/tests/unit/test_segment.py new file mode 100644 index 0000000000..e94bcf9809 --- /dev/null +++ b/tests/unit/test_segment.py @@ -0,0 +1,216 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
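The segment tests that follow all exercise the same header layout: for an uncompressed segment, a 3-byte little-endian field carrying a 17-bit payload length, one self-contained flag bit and padding, followed by a CRC24. A small sketch of just that bit packing (encode_uncompressed_flags is an invented helper for illustration, not the driver's codec):

MAX_PAYLOAD_LENGTH = (1 << 17) - 1       # 128 KiB - 1, one protocol-v5 segment

def encode_uncompressed_flags(payload_length, is_self_contained):
    if payload_length > MAX_PAYLOAD_LENGTH:
        raise ValueError("payload does not fit in a single segment")
    header = payload_length | (int(is_self_contained) << 17)
    return header.to_bytes(3, "little")  # the CRC24 of these bytes follows on the wire

# 50-byte self-contained payload: 17 length bits holding 50, flag bit set
assert encode_uncompressed_flags(50, True) == bytes([0x32, 0x00, 0x02])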
+ +import unittest + +from io import BytesIO + +from cassandra import DriverException +from cassandra.segment import Segment, CrcException +from cassandra.connection import segment_codec_no_compression, segment_codec_lz4 + + +def to_bits(b): + return '{:08b}'.format(b) + + +class SegmentCodecTest(unittest.TestCase): + + small_msg = b'b' * 50 + max_msg = b'b' * Segment.MAX_PAYLOAD_LENGTH + large_msg = b'b' * (Segment.MAX_PAYLOAD_LENGTH + 1) + + @staticmethod + def _header_to_bits(data): + # unpack a header to bits + # data should be the little endian bytes sequence + if len(data) > 6: # compressed + data = data[:5] + bits = ''.join([to_bits(b) for b in reversed(data)]) + # return the compressed payload length, the uncompressed payload length, + # the self-contained flag and the padding as bits + return bits[23:40] + bits[6:23] + bits[5:6] + bits[:5] + else: # uncompressed + data = data[:3] + bits = ''.join([to_bits(b) for b in reversed(data)]) + # return the payload length, the self-contained flag and + # the padding as bits + return bits[7:24] + bits[6:7] + bits[:6] + + def test_encode_uncompressed_header(self): + buffer = BytesIO() + segment_codec_no_compression.encode_header(buffer, len(self.small_msg), -1, True) + self.assertEqual(buffer.tell(), 6) + self.assertEqual( + self._header_to_bits(buffer.getvalue()), + "00000000000110010" + "1" + "000000") + + @unittest.skipUnless(segment_codec_lz4, ' lz4 not installed') + def test_encode_compressed_header(self): + buffer = BytesIO() + compressed_length = len(segment_codec_lz4.compress(self.small_msg)) + segment_codec_lz4.encode_header(buffer, compressed_length, len(self.small_msg), True) + + self.assertEqual(buffer.tell(), 8) + self.assertEqual( + self._header_to_bits(buffer.getvalue()), + "{:017b}".format(compressed_length) + "00000000000110010" + "1" + "00000") + + def test_encode_uncompressed_header_with_max_payload(self): + buffer = BytesIO() + segment_codec_no_compression.encode_header(buffer, len(self.max_msg), -1, True) + self.assertEqual(buffer.tell(), 6) + self.assertEqual( + self._header_to_bits(buffer.getvalue()), + "11111111111111111" + "1" + "000000") + + def test_encode_header_fails_if_payload_too_big(self): + buffer = BytesIO() + for codec in [c for c in [segment_codec_no_compression, segment_codec_lz4] if c is not None]: + with self.assertRaises(DriverException): + codec.encode_header(buffer, len(self.large_msg), -1, False) + + def test_encode_uncompressed_header_not_self_contained_msg(self): + buffer = BytesIO() + # simulate the first chunk with the max size + segment_codec_no_compression.encode_header(buffer, len(self.max_msg), -1, False) + self.assertEqual(buffer.tell(), 6) + self.assertEqual( + self._header_to_bits(buffer.getvalue()), + ("11111111111111111" + "0" # not self-contained + "000000")) + + @unittest.skipUnless(segment_codec_lz4, ' lz4 not installed') + def test_encode_compressed_header_with_max_payload(self): + buffer = BytesIO() + compressed_length = len(segment_codec_lz4.compress(self.max_msg)) + segment_codec_lz4.encode_header(buffer, compressed_length, len(self.max_msg), True) + self.assertEqual(buffer.tell(), 8) + self.assertEqual( + self._header_to_bits(buffer.getvalue()), + "{:017b}".format(compressed_length) + "11111111111111111" + "1" + "00000") + + @unittest.skipUnless(segment_codec_lz4, ' lz4 not installed') + def test_encode_compressed_header_not_self_contained_msg(self): + buffer = BytesIO() + # simulate the first chunk with the max size + compressed_length = 
len(segment_codec_lz4.compress(self.max_msg)) + segment_codec_lz4.encode_header(buffer, compressed_length, len(self.max_msg), False) + self.assertEqual(buffer.tell(), 8) + self.assertEqual( + self._header_to_bits(buffer.getvalue()), + ("{:017b}".format(compressed_length) + + "11111111111111111" + "0" # not self-contained + "00000")) + + def test_decode_uncompressed_header(self): + buffer = BytesIO() + segment_codec_no_compression.encode_header(buffer, len(self.small_msg), -1, True) + buffer.seek(0) + header = segment_codec_no_compression.decode_header(buffer) + self.assertEqual(header.uncompressed_payload_length, -1) + self.assertEqual(header.payload_length, len(self.small_msg)) + self.assertEqual(header.is_self_contained, True) + + @unittest.skipUnless(segment_codec_lz4, ' lz4 not installed') + def test_decode_compressed_header(self): + buffer = BytesIO() + compressed_length = len(segment_codec_lz4.compress(self.small_msg)) + segment_codec_lz4.encode_header(buffer, compressed_length, len(self.small_msg), True) + buffer.seek(0) + header = segment_codec_lz4.decode_header(buffer) + self.assertEqual(header.uncompressed_payload_length, len(self.small_msg)) + self.assertEqual(header.payload_length, compressed_length) + self.assertEqual(header.is_self_contained, True) + + def test_decode_header_fails_if_corrupted(self): + buffer = BytesIO() + segment_codec_no_compression.encode_header(buffer, len(self.small_msg), -1, True) + # corrupt one byte + buffer.seek(buffer.tell()-1) + buffer.write(b'0') + buffer.seek(0) + + with self.assertRaises(CrcException): + segment_codec_no_compression.decode_header(buffer) + + def test_decode_uncompressed_self_contained_segment(self): + buffer = BytesIO() + segment_codec_no_compression.encode(buffer, self.small_msg) + + buffer.seek(0) + header = segment_codec_no_compression.decode_header(buffer) + segment = segment_codec_no_compression.decode(buffer, header) + + self.assertEqual(header.is_self_contained, True) + self.assertEqual(header.uncompressed_payload_length, -1) + self.assertEqual(header.payload_length, len(self.small_msg)) + self.assertEqual(segment.payload, self.small_msg) + + @unittest.skipUnless(segment_codec_lz4, ' lz4 not installed') + def test_decode_compressed_self_contained_segment(self): + buffer = BytesIO() + segment_codec_lz4.encode(buffer, self.small_msg) + + buffer.seek(0) + header = segment_codec_lz4.decode_header(buffer) + segment = segment_codec_lz4.decode(buffer, header) + + self.assertEqual(header.is_self_contained, True) + self.assertEqual(header.uncompressed_payload_length, len(self.small_msg)) + self.assertGreater(header.uncompressed_payload_length, header.payload_length) + self.assertEqual(segment.payload, self.small_msg) + + def test_decode_multi_segments(self): + buffer = BytesIO() + segment_codec_no_compression.encode(buffer, self.large_msg) + + buffer.seek(0) + # We should have 2 segments to read + headers = [] + segments = [] + headers.append(segment_codec_no_compression.decode_header(buffer)) + segments.append(segment_codec_no_compression.decode(buffer, headers[0])) + headers.append(segment_codec_no_compression.decode_header(buffer)) + segments.append(segment_codec_no_compression.decode(buffer, headers[1])) + + self.assertTrue(all([h.is_self_contained is False for h in headers])) + decoded_msg = segments[0].payload + segments[1].payload + self.assertEqual(decoded_msg, self.large_msg) + + @unittest.skipUnless(segment_codec_lz4, ' lz4 not installed') + def test_decode_fails_if_corrupted(self): + buffer = BytesIO() + 
segment_codec_lz4.encode(buffer, self.small_msg) + buffer.seek(buffer.tell()-1) + buffer.write(b'0') + buffer.seek(0) + header = segment_codec_lz4.decode_header(buffer) + with self.assertRaises(CrcException): + segment_codec_lz4.decode(buffer, header) + + @unittest.skipUnless(segment_codec_lz4, ' lz4 not installed') + def test_decode_tiny_msg_not_compressed(self): + buffer = BytesIO() + segment_codec_lz4.encode(buffer, b'b') + buffer.seek(0) + header = segment_codec_lz4.decode_header(buffer) + segment = segment_codec_lz4.decode(buffer, header) + self.assertEqual(header.uncompressed_payload_length, 0) + self.assertEqual(header.payload_length, 1) + self.assertEqual(segment.payload, b'b') diff --git a/tests/unit/test_sortedset.py b/tests/unit/test_sortedset.py index a2b264710f..875485f824 100644 --- a/tests/unit/test_sortedset.py +++ b/tests/unit/test_sortedset.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,14 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
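The sortedset changes below are about one behaviour: keeping set semantics even for element mixes Python refuses to order, such as None alongside integers (PYTHON-1087). A short usage sketch mirroring what the tests assert, assuming the cassandra-driver package is installed:

from cassandra.util import sortedset

ss = sortedset((None, 1, 2, 6, None, None, 92))
assert ss == {None, 1, 2, 6, 92}        # duplicates collapse; equality with a plain set holds
assert 92 in ss
ss.remove(None)
assert None not in ss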
-try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest from cassandra.util import sortedset from cassandra.cqltypes import EMPTY +from datetime import datetime +from itertools import permutations class SortedSetTest(unittest.TestCase): def test_init(self): @@ -364,3 +365,38 @@ def test_reduce_pickle(self): s = pickle.dumps(ss) self.assertEqual(pickle.loads(s), ss) + def _test_uncomparable_types(self, items): + for perm in permutations(items): + ss = sortedset(perm) + s = set(perm) + self.assertEqual(s, ss) + self.assertEqual(ss, ss.union(s)) + for x in range(len(ss)): + subset = set(s) + for _ in range(x): + subset.pop() + self.assertEqual(ss.difference(subset), s.difference(subset)) + self.assertEqual(ss.intersection(subset), s.intersection(subset)) + for x in ss: + self.assertIn(x, ss) + ss.remove(x) + self.assertNotIn(x, ss) + + def test_uncomparable_types_with_tuples(self): + # PYTHON-1087 - make set handle uncomparable types + dt = datetime(2019, 5, 16) + items = (('samekey', 3, 1), + ('samekey', None, 0), + ('samekey', dt), + ("samekey", None, 2), + ("samekey", None, 1), + ('samekey', dt), + ('samekey', None, 0), + ("samekey", datetime.now())) + + self._test_uncomparable_types(items) + + def test_uncomparable_types_with_integers(self): + # PYTHON-1087 - make set handle uncomparable types + items = (None, 1, 2, 6, None, None, 92) + self._test_uncomparable_types(items) diff --git a/tests/unit/test_time_util.py b/tests/unit/test_time_util.py index 4455e58887..be5c984907 100644 --- a/tests/unit/test_time_util.py +++ b/tests/unit/test_time_util.py @@ -1,10 +1,12 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,10 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest from cassandra import marshal from cassandra import util diff --git a/tests/unit/test_timestamps.py b/tests/unit/test_timestamps.py new file mode 100644 index 0000000000..676cb6442a --- /dev/null +++ b/tests/unit/test_timestamps.py @@ -0,0 +1,273 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest import mock + +from cassandra import timestamps +from threading import Thread, Lock + + +class _TimestampTestMixin(object): + + @mock.patch('cassandra.timestamps.time') + def _call_and_check_results(self, + patched_time_module, + system_time_expected_stamp_pairs, + timestamp_generator=None): + """ + For each element in an iterable of (system_time, expected_timestamp) + pairs, call a :class:`cassandra.timestamps.MonotonicTimestampGenerator` + with system_times as the underlying time.time() result, then assert + that the result is expected_timestamp. Skips the check if + expected_timestamp is None. + """ + patched_time_module.time = mock.Mock() + system_times, expected_timestamps = zip(*system_time_expected_stamp_pairs) + + patched_time_module.time.side_effect = system_times + tsg = timestamp_generator or timestamps.MonotonicTimestampGenerator() + + for expected in expected_timestamps: + actual = tsg() + if expected is not None: + self.assertEqual(actual, expected) + + # assert we patched timestamps.time.time correctly + with self.assertRaises(StopIteration): + tsg() + + +class TestTimestampGeneratorOutput(unittest.TestCase, _TimestampTestMixin): + """ + Mock time.time and test the output of MonotonicTimestampGenerator.__call__ + given different patterns of changing results. + """ + + def test_timestamps_during_and_after_same_system_time(self): + """ + Timestamps should increase monotonically over repeated system time. + + Test that MonotonicTimestampGenerator's output increases by 1 when the + underlying system time is the same, then returns to normal when the + system time increases again. + + @since 3.8.0 + @expected_result Timestamps should increase monotonically over repeated system time. + @test_category timing + """ + self._call_and_check_results( + system_time_expected_stamp_pairs=( + (15.0, 15 * 1e6), + (15.0, 15 * 1e6 + 1), + (15.0, 15 * 1e6 + 2), + (15.01, 15.01 * 1e6)) + ) + + def test_timestamps_during_and_after_backwards_system_time(self): + """ + Timestamps should increase monotonically over system time going backwards. + + Test that MonotonicTimestampGenerator's output increases by 1 when the + underlying system time goes backward, then returns to normal when the + system time increases again. 
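The expected pairs in these tests encode a single rule: convert the wall clock to integer microseconds and never hand out a value less than or equal to the previous one, bumping by 1 microsecond while the clock stalls or runs backwards. A minimal single-threaded sketch of that rule (next_timestamp is an invented name, not the generator's API):

import time

_last = 0

def next_timestamp():
    global _last
    now = int(time.time() * 1e6)          # microseconds since the epoch
    _last = now if now > _last else _last + 1
    return _last

first = next_timestamp()
assert next_timestamp() > first           # monotonic even within the same microsecond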
+ """ + self._call_and_check_results( + system_time_expected_stamp_pairs=( + (15.0, 15 * 1e6), + (13.0, 15 * 1e6 + 1), + (14.0, 15 * 1e6 + 2), + (13.5, 15 * 1e6 + 3), + (15.01, 15.01 * 1e6)) + ) + + +class TestTimestampGeneratorLogging(unittest.TestCase): + + def setUp(self): + self.log_patcher = mock.patch('cassandra.timestamps.log') + self.addCleanup(self.log_patcher.stop) + self.patched_timestamp_log = self.log_patcher.start() + + def assertLastCallArgRegex(self, call, pattern): + last_warn_args, last_warn_kwargs = call + self.assertEqual(len(last_warn_args), 1) + self.assertEqual(len(last_warn_kwargs), 0) + self.assertRegex(last_warn_args[0], pattern) + + def test_basic_log_content(self): + """ + Tests there are logs + + @since 3.8.0 + @jira_ticket PYTHON-676 + @expected_result logs + + @test_category timing + """ + tsg = timestamps.MonotonicTimestampGenerator( + warning_threshold=1e-6, + warning_interval=1e-6 + ) + #The units of _last_warn is seconds + tsg._last_warn = 12 + + tsg._next_timestamp(20, tsg.last) + self.assertEqual(len(self.patched_timestamp_log.warning.call_args_list), 0) + tsg._next_timestamp(16, tsg.last) + + self.assertEqual(len(self.patched_timestamp_log.warning.call_args_list), 1) + self.assertLastCallArgRegex( + self.patched_timestamp_log.warning.call_args, + r'Clock skew detected:.*\b16\b.*\b4\b.*\b20\b' + ) + + def test_disable_logging(self): + """ + Tests there are no logs when there is a clock skew if logging is disabled + + @since 3.8.0 + @jira_ticket PYTHON-676 + @expected_result no logs + + @test_category timing + """ + no_warn_tsg = timestamps.MonotonicTimestampGenerator(warn_on_drift=False) + + no_warn_tsg.last = 100 + no_warn_tsg._next_timestamp(99, no_warn_tsg.last) + self.assertEqual(len(self.patched_timestamp_log.warning.call_args_list), 0) + + def test_warning_threshold_respected_no_logging(self): + """ + Tests there are no logs if `warning_threshold` is not exceeded + + @since 3.8.0 + @jira_ticket PYTHON-676 + @expected_result no logs + + @test_category timing + """ + tsg = timestamps.MonotonicTimestampGenerator( + warning_threshold=2e-6, + ) + tsg.last, tsg._last_warn = 100, 97 + tsg._next_timestamp(98, tsg.last) + self.assertEqual(len(self.patched_timestamp_log.warning.call_args_list), 0) + + def test_warning_threshold_respected_logs(self): + """ + Tests there are logs if `warning_threshold` is exceeded + + @since 3.8.0 + @jira_ticket PYTHON-676 + @expected_result logs + + @test_category timing + """ + tsg = timestamps.MonotonicTimestampGenerator( + warning_threshold=1e-6, + warning_interval=1e-6 + ) + tsg.last, tsg._last_warn = 100, 97 + tsg._next_timestamp(98, tsg.last) + self.assertEqual(len(self.patched_timestamp_log.warning.call_args_list), 1) + + def test_warning_interval_respected_no_logging(self): + """ + Tests there is only one log in the interval `warning_interval` + + @since 3.8.0 + @jira_ticket PYTHON-676 + @expected_result one log + + @test_category timing + """ + tsg = timestamps.MonotonicTimestampGenerator( + warning_threshold=1e-6, + warning_interval=2e-6 + ) + tsg.last = 100 + tsg._next_timestamp(70, tsg.last) + self.assertEqual(len(self.patched_timestamp_log.warning.call_args_list), 1) + + tsg._next_timestamp(71, tsg.last) + self.assertEqual(len(self.patched_timestamp_log.warning.call_args_list), 1) + + def test_warning_interval_respected_logs(self): + """ + Tests there are logs again if the + clock skew happens after`warning_interval` + + @since 3.8.0 + @jira_ticket PYTHON-676 + @expected_result logs + + @test_category timing + 
""" + tsg = timestamps.MonotonicTimestampGenerator( + warning_interval=1e-6, + warning_threshold=1e-6, + ) + tsg.last = 100 + tsg._next_timestamp(70, tsg.last) + self.assertEqual(len(self.patched_timestamp_log.warning.call_args_list), 1) + + tsg._next_timestamp(72, tsg.last) + self.assertEqual(len(self.patched_timestamp_log.warning.call_args_list), 2) + + +class TestTimestampGeneratorMultipleThreads(unittest.TestCase): + + def test_should_generate_incrementing_timestamps_for_all_threads(self): + """ + Tests when time is "stopped", values are assigned incrementally + + @since 3.8.0 + @jira_ticket PYTHON-676 + @expected_result the returned values increase + + @test_category timing + """ + lock = Lock() + + def request_time(): + for _ in range(timestamp_to_generate): + timestamp = tsg() + with lock: + generated_timestamps.append(timestamp) + + tsg = timestamps.MonotonicTimestampGenerator() + fixed_time = 1 + num_threads = 5 + + timestamp_to_generate = 1000 + generated_timestamps = [] + + with mock.patch('time.time', new=mock.Mock(return_value=fixed_time)): + threads = [] + for _ in range(num_threads): + threads.append(Thread(target=request_time)) + + for t in threads: + t.start() + + for t in threads: + t.join() + + self.assertEqual(len(generated_timestamps), num_threads * timestamp_to_generate) + for i, timestamp in enumerate(sorted(generated_timestamps)): + self.assertEqual(int(i + 1e6), timestamp) diff --git a/tests/unit/test_types.py b/tests/unit/test_types.py index d8774b0299..ba01538b2a 100644 --- a/tests/unit/test_types.py +++ b/tests/unit/test_types.py @@ -1,38 +1,52 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest -from binascii import unhexlify import datetime import tempfile -import six import time +import uuid +from binascii import unhexlify import cassandra -from cassandra.cqltypes import (BooleanType, lookup_casstype_simple, lookup_casstype, - LongType, DecimalType, SetType, cql_typename, - CassandraType, UTF8Type, parse_casstype_args, - SimpleDateType, TimeType, ByteType, ShortType, - EmptyValue, _CassandraType, DateType, int64_pack) +from cassandra import util +from cassandra.cqltypes import ( + CassandraType, DateRangeType, DateType, DecimalType, + EmptyValue, LongType, SetType, UTF8Type, + cql_typename, int8_pack, int64_pack, lookup_casstype, + lookup_casstype_simple, parse_casstype_args, + int32_pack, Int32Type, ListType, MapType, VectorType, + FloatType +) from cassandra.encoder import cql_quote -from cassandra.protocol import (write_string, read_longstring, write_stringmap, - read_stringmap, read_inet, write_inet, - read_string, write_longstring) +from cassandra.pool import Host +from cassandra.metadata import Token +from cassandra.policies import ConvictionPolicy, SimpleConvictionPolicy +from cassandra.protocol import ( + read_inet, read_longstring, read_string, + read_stringmap, write_inet, write_longstring, + write_string, write_stringmap +) from cassandra.query import named_tuple_factory +from cassandra.util import ( + OPEN_BOUND, Date, DateRange, DateRangeBound, + DateRangePrecision, Time, ms_timestamp_from_datetime, + datetime_from_timestamp +) +from tests.unit.util import check_sequence_consistency class TypeTests(unittest.TestCase): @@ -67,6 +81,8 @@ def test_lookup_casstype_simple(self): self.assertEqual(lookup_casstype_simple('CompositeType'), cassandra.cqltypes.CompositeType) self.assertEqual(lookup_casstype_simple('ColumnToCollectionType'), cassandra.cqltypes.ColumnToCollectionType) self.assertEqual(lookup_casstype_simple('ReversedType'), cassandra.cqltypes.ReversedType) + self.assertEqual(lookup_casstype_simple('DurationType'), cassandra.cqltypes.DurationType) + self.assertEqual(lookup_casstype_simple('DateRangeType'), cassandra.cqltypes.DateRangeType) self.assertEqual(str(lookup_casstype_simple('unknown')), str(cassandra.cqltypes.mkUnrecognizedType('unknown'))) @@ -100,6 +116,8 @@ def test_lookup_casstype(self): self.assertEqual(lookup_casstype('CompositeType'), cassandra.cqltypes.CompositeType) self.assertEqual(lookup_casstype('ColumnToCollectionType'), cassandra.cqltypes.ColumnToCollectionType) self.assertEqual(lookup_casstype('ReversedType'), cassandra.cqltypes.ReversedType) + self.assertEqual(lookup_casstype('DurationType'), cassandra.cqltypes.DurationType) + self.assertEqual(lookup_casstype('DateRangeType'), cassandra.cqltypes.DateRangeType) self.assertEqual(str(lookup_casstype('unknown')), str(cassandra.cqltypes.mkUnrecognizedType('unknown'))) @@ -150,16 +168,16 @@ def __init__(self, subtypes, names): @classmethod def apply_parameters(cls, subtypes, names): - return cls(subtypes, [unhexlify(six.b(name)) if name is not None else name for name in names]) + return cls(subtypes, [unhexlify(name.encode()) if name is not None else name for name in names]) class BarType(FooType): typename = 'org.apache.cassandra.db.marshal.BarType' ctype = parse_casstype_args(''.join(( 'org.apache.cassandra.db.marshal.FooType(', - '63697479:org.apache.cassandra.db.marshal.UTF8Type,', - 'BarType(61646472657373:org.apache.cassandra.db.marshal.UTF8Type),', - 
'7a6970:org.apache.cassandra.db.marshal.UTF8Type', + '63697479:org.apache.cassandra.db.marshal.UTF8Type,', + 'BarType(61646472657373:org.apache.cassandra.db.marshal.UTF8Type),', + '7a6970:org.apache.cassandra.db.marshal.UTF8Type', ')'))) self.assertEqual(FooType, ctype.__class__) @@ -174,12 +192,28 @@ class BarType(FooType): self.assertEqual(UTF8Type, ctype.subtypes[2]) self.assertEqual([b'city', None, b'zip'], ctype.names) + def test_parse_casstype_vector(self): + ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 3)") + self.assertTrue(issubclass(ctype, VectorType)) + self.assertEqual(3, ctype.vector_size) + self.assertEqual(FloatType, ctype.subtype) + + def test_parse_casstype_vector_of_vectors(self): + inner_type = "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)" + ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(%s, 3)" % (inner_type)) + self.assertTrue(issubclass(ctype, VectorType)) + self.assertEqual(3, ctype.vector_size) + sub_ctype = ctype.subtype + self.assertTrue(issubclass(sub_ctype, VectorType)) + self.assertEqual(4, sub_ctype.vector_size) + self.assertEqual(FloatType, sub_ctype.subtype) + def test_empty_value(self): self.assertEqual(str(EmptyValue()), 'EMPTY') def test_datetype(self): now_time_seconds = time.time() - now_datetime = datetime.datetime.utcfromtimestamp(now_time_seconds) + now_datetime = datetime.datetime.fromtimestamp(now_time_seconds, tz=datetime.timezone.utc) # Cassandra timestamps in millis now_timestamp = now_time_seconds * 1e3 @@ -190,23 +224,63 @@ def test_datetype(self): # deserialize # epoc expected = 0 - self.assertEqual(DateType.deserialize(int64_pack(1000 * expected), 0), datetime.datetime.utcfromtimestamp(expected)) + self.assertEqual(DateType.deserialize(int64_pack(1000 * expected), 0), datetime.datetime.fromtimestamp(expected, tz=datetime.timezone.utc).replace(tzinfo=None)) # beyond 32b expected = 2 ** 33 - self.assertEqual(DateType.deserialize(int64_pack(1000 * expected), 0), datetime.datetime(2242, 3, 16, 12, 56, 32)) + self.assertEqual(DateType.deserialize(int64_pack(1000 * expected), 0), datetime.datetime(2242, 3, 16, 12, 56, 32, tzinfo=datetime.timezone.utc).replace(tzinfo=None)) # less than epoc (PYTHON-119) expected = -770172256 - self.assertEqual(DateType.deserialize(int64_pack(1000 * expected), 0), datetime.datetime(1945, 8, 5, 23, 15, 44)) + self.assertEqual(DateType.deserialize(int64_pack(1000 * expected), 0), datetime.datetime(1945, 8, 5, 23, 15, 44, tzinfo=datetime.timezone.utc).replace(tzinfo=None)) # work around rounding difference among Python versions (PYTHON-230) expected = 1424817268.274 - self.assertEqual(DateType.deserialize(int64_pack(int(1000 * expected)), 0), datetime.datetime(2015, 2, 24, 22, 34, 28, 274000)) + self.assertEqual(DateType.deserialize(int64_pack(int(1000 * expected)), 0), datetime.datetime(2015, 2, 24, 22, 34, 28, 274000, tzinfo=datetime.timezone.utc).replace(tzinfo=None)) # Large date overflow (PYTHON-452) expected = 2177403010.123 - self.assertEqual(DateType.deserialize(int64_pack(int(1000 * expected)), 0), datetime.datetime(2038, 12, 31, 10, 10, 10, 123000)) + self.assertEqual(DateType.deserialize(int64_pack(int(1000 * expected)), 0), datetime.datetime(2038, 12, 31, 10, 10, 10, 123000, tzinfo=datetime.timezone.utc).replace(tzinfo=None)) + + def test_collection_null_support(self): + """ + Test that null values in collection are decoded properly. 
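The byte strings assembled in this test follow the protocol's collection encoding: a 32-bit element count, then for each element a 32-bit size (with -1 meaning NULL) followed by that many bytes. A standalone sketch of decoding an int list under that assumption (decode_int_list is illustrative, not driver code):

import struct

def decode_int_list(data):
    (count,) = struct.unpack_from(">i", data, 0)
    offset, items = 4, []
    for _ in range(count):
        (size,) = struct.unpack_from(">i", data, offset)
        offset += 4
        if size < 0:
            items.append(None)            # negative size marks a NULL element
        else:
            (value,) = struct.unpack_from(">i", data, offset)
            offset += size
            items.append(value)
    return items

payload = struct.pack(">iiii", 2, -1, 4, 42)   # [2 items][NULL][int 42]
assert decode_int_list(payload) == [None, 42]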
+ + @jira_ticket PYTHON-1123 + """ + int_list = ListType.apply_parameters([Int32Type]) + value = ( + int32_pack(2) + # num items + int32_pack(-1) + # size of item1 + int32_pack(4) + # size of item2 + int32_pack(42) # item2 + ) + self.assertEqual( + [None, 42], + int_list.deserialize(value, 3) + ) + + set_list = SetType.apply_parameters([Int32Type]) + self.assertEqual( + {None, 42}, + set(set_list.deserialize(value, 3)) + ) + + value = ( + int32_pack(2) + # num items + int32_pack(4) + # key size of item1 + int32_pack(42) + # key item1 + int32_pack(-1) + # value size of item1 + int32_pack(-1) + # key size of item2 + int32_pack(4) + # value size of item2 + int32_pack(42) # value of item2 + ) + + map_list = MapType.apply_parameters([Int32Type, Int32Type]) + self.assertEqual( + [(42, None), (None, 42)], + map_list.deserialize(value, 3)._items # OrderedMapSerializedKey + ) def test_write_read_string(self): with tempfile.TemporaryFile() as f: @@ -246,3 +320,824 @@ def test_cql_quote(self): self.assertEqual(cql_quote(u'test'), "'test'") self.assertEqual(cql_quote('test'), "'test'") self.assertEqual(cql_quote(0), '0') + + +class VectorTests(unittest.TestCase): + def _normalize_set(self, val): + if isinstance(val, set) or isinstance(val, util.SortedSet): + return frozenset([self._normalize_set(v) for v in val]) + return val + + def _round_trip_compare_fn(self, first, second): + if isinstance(first, float): + self.assertAlmostEqual(first, second, places=5) + elif isinstance(first, list): + self.assertEqual(len(first), len(second)) + for (felem, selem) in zip(first, second): + self._round_trip_compare_fn(felem, selem) + elif isinstance(first, set) or isinstance(first, frozenset): + self.assertEqual(len(first), len(second)) + first_norm = self._normalize_set(first) + second_norm = self._normalize_set(second) + self.assertEqual(first_norm, second_norm) + elif isinstance(first, dict): + for ((fk,fv), (sk,sv)) in zip(first.items(), second.items()): + self._round_trip_compare_fn(fk, sk) + self._round_trip_compare_fn(fv, sv) + else: + self.assertEqual(first,second) + + def _round_trip_test(self, data, ctype_str): + ctype = parse_casstype_args(ctype_str) + data_bytes = ctype.serialize(data, 0) + serialized_size = ctype.subtype.serial_size() + if serialized_size: + self.assertEqual(serialized_size * len(data), len(data_bytes)) + result = ctype.deserialize(data_bytes, 0) + self.assertEqual(len(data), len(result)) + for idx in range(0,len(data)): + self._round_trip_compare_fn(data[idx], result[idx]) + + def test_round_trip_basic_types_with_fixed_serialized_size(self): + self._round_trip_test([True, False, False, True], \ + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.BooleanType, 4)") + self._round_trip_test([3.4, 2.9, 41.6, 12.0], \ + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)") + self._round_trip_test([3.4, 2.9, 41.6, 12.0], \ + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.DoubleType, 4)") + self._round_trip_test([3, 2, 41, 12], \ + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.LongType, 4)") + self._round_trip_test([3, 2, 41, 12], \ + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.Int32Type, 4)") + self._round_trip_test([uuid.uuid1(), uuid.uuid1(), uuid.uuid1(), uuid.uuid1()], \ + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.TimeUUIDType, 4)") + self._round_trip_test([3, 2, 41, 12], \ + 
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.ShortType, 4)") + + def test_round_trip_basic_types_without_fixed_serialized_size(self): + # Varints + self._round_trip_test([3, 2, 41, 12], \ + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 4)") + # ASCII text + self._round_trip_test(["abc", "def", "ghi", "jkl"], \ + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.AsciiType, 4)") + # UTF8 text + self._round_trip_test(["abc", "def", "ghi", "jkl"], \ + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.UTF8Type, 4)") + # Time is something of a weird one. By rights, it should be a fixed size type but C* code marks it as variable + # size. We're forced to follow the C* code base (since that's who'll be providing the data we're parsing) so + # we match what they're doing. + self._round_trip_test([datetime.time(1,1,1), datetime.time(2,2,2), datetime.time(3,3,3)], \ + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.TimeType, 3)") + # Duration (contains varints) + self._round_trip_test([util.Duration(1,1,1), util.Duration(2,2,2), util.Duration(3,3,3)], \ + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.DurationType, 3)") + + def test_round_trip_collection_types(self): + # List (subtype of fixed size) + self._round_trip_test([[1, 2, 3, 4], [5, 6], [7, 8, 9, 10], [11, 12]], \ + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.ListType \ + (org.apache.cassandra.db.marshal.Int32Type), 4)") + # Set (subtype of fixed size) + self._round_trip_test([set([1, 2, 3, 4]), set([5, 6]), set([7, 8, 9, 10]), set([11, 12])], \ + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.SetType \ + (org.apache.cassandra.db.marshal.Int32Type), 4)") + # Map (subtype of fixed size) + self._round_trip_test([{1:1.2}, {2:3.4}, {3:5.6}, {4:7.8}], \ + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.MapType \ + (org.apache.cassandra.db.marshal.Int32Type,org.apache.cassandra.db.marshal.FloatType), 4)") + # List (subtype without fixed size) + self._round_trip_test([["one","two"], ["three","four"], ["five","six"], ["seven","eight"]], \ + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.ListType \ + (org.apache.cassandra.db.marshal.AsciiType), 4)") + # Set (subtype without fixed size) + self._round_trip_test([set(["one","two"]), set(["three","four"]), set(["five","six"]), set(["seven","eight"])], \ + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.SetType \ + (org.apache.cassandra.db.marshal.AsciiType), 4)") + # Map (subtype without fixed size) + self._round_trip_test([{1:"one"}, {2:"two"}, {3:"three"}, {4:"four"}], \ + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.MapType \ + (org.apache.cassandra.db.marshal.IntegerType,org.apache.cassandra.db.marshal.AsciiType), 4)") + # List of lists (subtype without fixed size) + data = [[["one","two"],["three"]], [["four"],["five"]], [["six","seven","eight"]], [["nine"]]] + ctype = "org.apache.cassandra.db.marshal.VectorType\ + (org.apache.cassandra.db.marshal.ListType\ + (org.apache.cassandra.db.marshal.ListType\ + (org.apache.cassandra.db.marshal.AsciiType)), 4)" + self._round_trip_test(data, ctype) + # Set of sets (subtype without fixed size) + data = [set([frozenset(["one","two"]),frozenset(["three"])]),\ + set([frozenset(["four"]),frozenset(["five"])]),\ 
+ set([frozenset(["six","seven","eight"])]), + set([frozenset(["nine"])])] + ctype = "org.apache.cassandra.db.marshal.VectorType\ + (org.apache.cassandra.db.marshal.SetType\ + (org.apache.cassandra.db.marshal.SetType\ + (org.apache.cassandra.db.marshal.AsciiType)), 4)" + self._round_trip_test(data, ctype) + # Map of maps (subtype without fixed size) + data = [{100:{1:"one",2:"two",3:"three"}},\ + {200:{4:"four",5:"five"}},\ + {300:{}},\ + {400:{6:"six"}}] + ctype = "org.apache.cassandra.db.marshal.VectorType\ + (org.apache.cassandra.db.marshal.MapType\ + (org.apache.cassandra.db.marshal.Int32Type,\ + org.apache.cassandra.db.marshal.MapType \ + (org.apache.cassandra.db.marshal.IntegerType,org.apache.cassandra.db.marshal.AsciiType)), 4)" + self._round_trip_test(data, ctype) + + def test_round_trip_vector_of_vectors(self): + # Subytpes of subtypes with a fixed size + self._round_trip_test([[1.2, 3.4], [5.6, 7.8], [9.10, 11.12], [13.14, 15.16]], \ + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.VectorType \ + (org.apache.cassandra.db.marshal.FloatType,2), 4)") + + # Subytpes of subtypes without a fixed size + self._round_trip_test([["one", "two"], ["three", "four"], ["five", "six"], ["seven", "eight"]], \ + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.VectorType \ + (org.apache.cassandra.db.marshal.AsciiType,2), 4)") + + # parse_casstype_args() is tested above... we're explicitly concerned about cql_parameterized_type() output here + def test_cql_parameterized_type(self): + # Base vector functionality + ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)") + self.assertEqual(ctype.cql_parameterized_type(), "org.apache.cassandra.db.marshal.VectorType") + + # Test vector-of-vectors + inner_type = "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)" + ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(%s, 3)" % (inner_type)) + inner_parsed_type = "org.apache.cassandra.db.marshal.VectorType" + self.assertEqual(ctype.cql_parameterized_type(), "org.apache.cassandra.db.marshal.VectorType<%s, 3>" % (inner_parsed_type)) + + def test_serialization_fixed_size_too_small(self): + ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 5)") + with self.assertRaisesRegex(ValueError, "Expected sequence of size 5 for vector of type float and dimension 5, observed sequence of length 4"): + ctype.serialize([1.2, 3.4, 5.6, 7.8], 0) + + def test_serialization_fixed_size_too_big(self): + ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)") + with self.assertRaisesRegex(ValueError, "Expected sequence of size 4 for vector of type float and dimension 4, observed sequence of length 5"): + ctype.serialize([1.2, 3.4, 5.6, 7.8, 9.10], 0) + + def test_serialization_variable_size_too_small(self): + ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 5)") + with self.assertRaisesRegex(ValueError, "Expected sequence of size 5 for vector of type varint and dimension 5, observed sequence of length 4"): + ctype.serialize([1, 2, 3, 4], 0) + + def test_serialization_variable_size_too_big(self): + ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 4)") + with self.assertRaisesRegex(ValueError, 
"Expected sequence of size 4 for vector of type varint and dimension 4, observed sequence of length 5"): + ctype.serialize([1, 2, 3, 4, 5], 0) + + def test_deserialization_fixed_size_too_small(self): + ctype_four = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)") + ctype_four_bytes = ctype_four.serialize([1.2, 3.4, 5.6, 7.8], 0) + ctype_five = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 5)") + with self.assertRaisesRegex(ValueError, "Expected vector of type float and dimension 5 to have serialized size 20; observed serialized size of 16 instead"): + ctype_five.deserialize(ctype_four_bytes, 0) + + def test_deserialization_fixed_size_too_big(self): + ctype_five = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 5)") + ctype_five_bytes = ctype_five.serialize([1.2, 3.4, 5.6, 7.8, 9.10], 0) + ctype_four = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)") + with self.assertRaisesRegex(ValueError, "Expected vector of type float and dimension 4 to have serialized size 16; observed serialized size of 20 instead"): + ctype_four.deserialize(ctype_five_bytes, 0) + + def test_deserialization_variable_size_too_small(self): + ctype_four = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 4)") + ctype_four_bytes = ctype_four.serialize([1, 2, 3, 4], 0) + ctype_five = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 5)") + with self.assertRaisesRegex(ValueError, "Error reading additional data during vector deserialization after successfully adding 4 elements"): + ctype_five.deserialize(ctype_four_bytes, 0) + + def test_deserialization_variable_size_too_big(self): + ctype_five = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 5)") + ctype_five_bytes = ctype_five.serialize([1, 2, 3, 4, 5], 0) + ctype_four = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 4)") + with self.assertRaisesRegex(ValueError, "Additional bytes remaining after vector deserialization completed"): + ctype_four.deserialize(ctype_five_bytes, 0) + + +ZERO = datetime.timedelta(0) + + +class UTC(datetime.tzinfo): + """UTC""" + + def utcoffset(self, dt): + return ZERO + + def tzname(self, dt): + return "UTC" + + def dst(self, dt): + return ZERO + + +try: + utc_timezone = datetime.timezone.utc +except AttributeError: + utc_timezone = UTC() + + +class DateRangeTypeTests(unittest.TestCase): + dt = datetime.datetime(1990, 2, 3, 13, 58, 45, 777777) + timestamp = 1485963732404 + + def test_month_rounding_creation_failure(self): + """ + @jira_ticket PYTHON-912 + """ + feb_stamp = ms_timestamp_from_datetime( + datetime.datetime(2018, 2, 25, 18, 59, 59, 0) + ) + dr = DateRange(OPEN_BOUND, + DateRangeBound(feb_stamp, DateRangePrecision.MONTH)) + dt = datetime_from_timestamp(dr.upper_bound.milliseconds / 1000) + self.assertEqual(dt.day, 28) + + # Leap year + feb_stamp_leap_year = ms_timestamp_from_datetime( + datetime.datetime(2016, 2, 25, 18, 59, 59, 0) + ) + dr = DateRange(OPEN_BOUND, + DateRangeBound(feb_stamp_leap_year, DateRangePrecision.MONTH)) + dt = datetime_from_timestamp(dr.upper_bound.milliseconds / 1000) + self.assertEqual(dt.day, 29) + + def 
test_decode_precision(self): + self.assertEqual(DateRangeType._decode_precision(6), 'MILLISECOND') + + def test_decode_precision_error(self): + with self.assertRaises(ValueError): + DateRangeType._decode_precision(-1) + + def test_encode_precision(self): + self.assertEqual(DateRangeType._encode_precision('SECOND'), 5) + + def test_encode_precision_error(self): + with self.assertRaises(ValueError): + DateRangeType._encode_precision('INVALID') + + def test_deserialize_single_value(self): + serialized = (int8_pack(0) + + int64_pack(self.timestamp) + + int8_pack(3)) + self.assertEqual( + DateRangeType.deserialize(serialized, 5), + util.DateRange(value=util.DateRangeBound( + value=datetime.datetime(2017, 2, 1, 15, 42, 12, 404000), + precision='HOUR') + ) + ) + + def test_deserialize_closed_range(self): + serialized = (int8_pack(1) + + int64_pack(self.timestamp) + + int8_pack(2) + + int64_pack(self.timestamp) + + int8_pack(6)) + self.assertEqual( + DateRangeType.deserialize(serialized, 5), + util.DateRange( + lower_bound=util.DateRangeBound( + value=datetime.datetime(2017, 2, 1, 0, 0), + precision='DAY' + ), + upper_bound=util.DateRangeBound( + value=datetime.datetime(2017, 2, 1, 15, 42, 12, 404000), + precision='MILLISECOND' + ) + ) + ) + + def test_deserialize_open_high(self): + serialized = (int8_pack(2) + + int64_pack(self.timestamp) + + int8_pack(3)) + deserialized = DateRangeType.deserialize(serialized, 5) + self.assertEqual( + deserialized, + util.DateRange( + lower_bound=util.DateRangeBound( + value=datetime.datetime(2017, 2, 1, 15, 0), + precision='HOUR' + ), + upper_bound=util.OPEN_BOUND + ) + ) + + def test_deserialize_open_low(self): + serialized = (int8_pack(3) + + int64_pack(self.timestamp) + + int8_pack(4)) + deserialized = DateRangeType.deserialize(serialized, 5) + self.assertEqual( + deserialized, + util.DateRange( + lower_bound=util.OPEN_BOUND, + upper_bound=util.DateRangeBound( + value=datetime.datetime(2017, 2, 1, 15, 42, 20, 1000), + precision='MINUTE' + ) + ) + ) + + def test_deserialize_single_open(self): + self.assertEqual( + util.DateRange(value=util.OPEN_BOUND), + DateRangeType.deserialize(int8_pack(5), 5) + ) + + def test_serialize_single_value(self): + serialized = (int8_pack(0) + + int64_pack(self.timestamp) + + int8_pack(5)) + deserialized = DateRangeType.deserialize(serialized, 5) + self.assertEqual( + deserialized, + util.DateRange( + value=util.DateRangeBound( + value=datetime.datetime(2017, 2, 1, 15, 42, 12), + precision='SECOND' + ) + ) + ) + + def test_serialize_closed_range(self): + serialized = (int8_pack(1) + + int64_pack(self.timestamp) + + int8_pack(5) + + int64_pack(self.timestamp) + + int8_pack(0)) + deserialized = DateRangeType.deserialize(serialized, 5) + self.assertEqual( + deserialized, + util.DateRange( + lower_bound=util.DateRangeBound( + value=datetime.datetime(2017, 2, 1, 15, 42, 12), + precision='SECOND' + ), + upper_bound=util.DateRangeBound( + value=datetime.datetime(2017, 12, 31), + precision='YEAR' + ) + ) + ) + + def test_serialize_open_high(self): + serialized = (int8_pack(2) + + int64_pack(self.timestamp) + + int8_pack(2)) + deserialized = DateRangeType.deserialize(serialized, 5) + self.assertEqual( + deserialized, + util.DateRange( + lower_bound=util.DateRangeBound( + value=datetime.datetime(2017, 2, 1), + precision='DAY' + ), + upper_bound=util.OPEN_BOUND + ) + ) + + def test_serialize_open_low(self): + serialized = (int8_pack(2) + + int64_pack(self.timestamp) + + int8_pack(3)) + deserialized = DateRangeType.deserialize(serialized, 5) 
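The `DateRangeType` cases in this hunk build payloads by hand: a leading type byte (0 = single value, 1 = closed range, 2 = open high, 3 = open low, 4 = both open, 5 = single open bound, as inferred from the tests) followed by one `<int64 ms timestamp><int8 precision code>` pair per bound, with precision codes 0–6 running YEAR through MILLISECOND. A hedged standalone sketch of the closed-range case, mirroring `test_deserialize_closed_range` and reusing the same imports this module relies on:

```python
import datetime
from cassandra import util
from cassandra.cqltypes import DateRangeType
from cassandra.marshal import int8_pack, int64_pack

# 1485963732404 ms == 2017-02-01 15:42:12.404 UTC (the class-level timestamp above).
payload = (int8_pack(1)                                  # closed range [lower, upper]
           + int64_pack(1485963732404) + int8_pack(2)    # lower bound, precision DAY
           + int64_pack(1485963732404) + int8_pack(6))   # upper bound, precision MILLISECOND

assert DateRangeType.deserialize(payload, 5) == util.DateRange(
    lower_bound=util.DateRangeBound(value=datetime.datetime(2017, 2, 1, 0, 0),
                                    precision='DAY'),
    upper_bound=util.DateRangeBound(value=datetime.datetime(2017, 2, 1, 15, 42, 12, 404000),
                                    precision='MILLISECOND'))
```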
+ self.assertEqual( + deserialized, + util.DateRange( + lower_bound=util.DateRangeBound( + value=datetime.datetime(2017, 2, 1, 15), + precision='HOUR' + ), + upper_bound=util.OPEN_BOUND + ) + ) + + def test_deserialize_both_open(self): + serialized = (int8_pack(4)) + deserialized = DateRangeType.deserialize(serialized, 5) + self.assertEqual( + deserialized, + util.DateRange( + lower_bound=util.OPEN_BOUND, + upper_bound=util.OPEN_BOUND + ) + ) + + def test_serialize_single_open(self): + serialized = DateRangeType.serialize(util.DateRange( + value=util.OPEN_BOUND, + ), 5) + self.assertEqual(int8_pack(5), serialized) + + def test_serialize_both_open(self): + serialized = DateRangeType.serialize(util.DateRange( + lower_bound=util.OPEN_BOUND, + upper_bound=util.OPEN_BOUND + ), 5) + self.assertEqual(int8_pack(4), serialized) + + def test_failure_to_serialize_no_value_object(self): + self.assertRaises(ValueError, DateRangeType.serialize, object(), 5) + + def test_failure_to_serialize_no_bounds_object(self): + class no_bounds_object(object): + value = lower_bound = None + self.assertRaises(ValueError, DateRangeType.serialize, no_bounds_object, 5) + + def test_serialized_value_round_trip(self): + vals = [b'\x01\x00\x00\x01%\xe9a\xf9\xd1\x06\x00\x00\x01v\xbb>o\xff\x00', + b'\x01\x00\x00\x00\xdcm\x03-\xd1\x06\x00\x00\x01v\xbb>o\xff\x00'] + for serialized in vals: + self.assertEqual( + serialized, + DateRangeType.serialize(DateRangeType.deserialize(serialized, 0), 0) + ) + + def test_serialize_zero_datetime(self): + """ + Test serialization where timestamp = 0 + + Companion test for test_deserialize_zero_datetime + + @since 2.0.0 + @jira_ticket PYTHON-729 + @expected_result serialization doesn't raise an error + + @test_category data_types + """ + DateRangeType.serialize(util.DateRange( + lower_bound=(datetime.datetime(1970, 1, 1), 'YEAR'), + upper_bound=(datetime.datetime(1970, 1, 1), 'YEAR') + ), 5) + + def test_deserialize_zero_datetime(self): + """ + Test deserialization where timestamp = 0 + + Reproduces PYTHON-729 + + @since 2.0.0 + @jira_ticket PYTHON-729 + @expected_result deserialization doesn't raise an error + + @test_category data_types + """ + DateRangeType.deserialize( + (int8_pack(1) + + int64_pack(0) + int8_pack(0) + + int64_pack(0) + int8_pack(0)), + 5 + ) + + +class DateRangeDeserializationTests(unittest.TestCase): + """ + These tests iterate over different timestamp values + and assert deserialization gives the expected value + """ + + starting_lower_value = 1514744108923 + """ + Sample starting value for the lower bound for DateRange + """ + starting_upper_value = 2148761288922 + """ + Sample starting value for the upper bound for DateRange + """ + + epoch = datetime.datetime(1970, 1, 1, tzinfo=utc_timezone) + + def test_deserialize_date_range_milliseconds(self): + """ + Test rounding from DateRange for milliseconds + + @since 2.0.0 + @jira_ticket PYTHON-898 + @expected_result + + @test_category data_types + """ + for i in range(1000): + lower_value = self.starting_lower_value + i + upper_value = self.starting_upper_value + i + dr = DateRange(DateRangeBound(lower_value, DateRangePrecision.MILLISECOND), + DateRangeBound(upper_value, DateRangePrecision.MILLISECOND)) + self.assertEqual(lower_value, dr.lower_bound.milliseconds) + self.assertEqual(upper_value, dr.upper_bound.milliseconds) + + def test_deserialize_date_range_seconds(self): + """ + Test rounding from DateRange for milliseconds + + @since 2.0.0 + @jira_ticket PYTHON-898 + @expected_result + + @test_category data_types + 
""" + + def truncate_last_figures(number, n=3): + """ + Truncates last n digits of a number + """ + return int(str(number)[:-n] + '0' * n) + + for i in range(1000): + lower_value = self.starting_lower_value + i * 900 + upper_value = self.starting_upper_value + i * 900 + dr = DateRange(DateRangeBound(lower_value, DateRangePrecision.SECOND), + DateRangeBound(upper_value, DateRangePrecision.SECOND)) + + self.assertEqual(truncate_last_figures(lower_value), dr.lower_bound.milliseconds) + upper_value = truncate_last_figures(upper_value) + 999 + self.assertEqual(upper_value, dr.upper_bound.milliseconds) + + def test_deserialize_date_range_minutes(self): + """ + Test rounding from DateRange for seconds + + @since 2.4.0 + @jira_ticket PYTHON-898 + @expected_result + + @test_category data_types + """ + self._deserialize_date_range({"second": 0, "microsecond": 0}, + DateRangePrecision.MINUTE, + # This lambda function given a truncated date adds + # one day minus one microsecond in microseconds + lambda x: x + 59 * 1000 + 999, + lambda original_value, i: original_value + i * 900 * 50) + + def test_deserialize_date_range_hours(self): + """ + Test rounding from DateRange for hours + + @since 2.4.0 + @jira_ticket PYTHON-898 + @expected_result + + @test_category data_types + """ + self._deserialize_date_range({"minute": 0, "second": 0, "microsecond": 0}, + DateRangePrecision.HOUR, + # This lambda function given a truncated date adds + # one hour minus one microsecond in microseconds + lambda x: x + + 59 * 60 * 1000 + + 59 * 1000 + + 999, + lambda original_value, i: original_value + i * 900 * 50 * 60) + + def test_deserialize_date_range_day(self): + """ + Test rounding from DateRange for hours + + @since 2.4.0 + @jira_ticket PYTHON-898 + @expected_result + + @test_category data_types + """ + self._deserialize_date_range({"hour": 0, "minute": 0, "second": 0, "microsecond": 0}, + DateRangePrecision.DAY, + # This lambda function given a truncated date adds + # one day minus one microsecond in microseconds + lambda x: x + + 23 * 60 * 60 * 1000 + + 59 * 60 * 1000 + + 59 * 1000 + + 999, + lambda original_value, i: original_value + i * 900 * 50 * 60 * 24) + + @unittest.skip("This is currently failing, see PYTHON-912") + def test_deserialize_date_range_month(self): + """ + Test rounding from DateRange for months + + @since 2.4.0 + @jira_ticket PYTHON-898 + @expected_result + + @test_category data_types + """ + def get_upper_bound(seconds): + """ + function that given a truncated date in seconds from the epoch returns that same date + but with the microseconds set to 999999, seconds to 59, minutes to 59, hours to 23 + and days 28, 29, 30 or 31 depending on the month. + The way to do this is to add one month and leave the date at YEAR-MONTH-01 00:00:00 000000. + Then subtract one millisecond. 
+ """ + dt = datetime.datetime.fromtimestamp(seconds / 1000.0, tz=utc_timezone) + dt = dt + datetime.timedelta(days=32) + dt = dt.replace(day=1) - datetime.timedelta(microseconds=1) + return int((dt - self.epoch).total_seconds() * 1000) + self._deserialize_date_range({"day": 1, "hour": 0, "minute": 0, "second": 0, "microsecond": 0}, + DateRangePrecision.MONTH, + get_upper_bound, + lambda original_value, i: original_value + i * 900 * 50 * 60 * 24 * 30) + + def test_deserialize_date_range_year(self): + """ + Test rounding from DateRange for year + + @since 2.4.0 + @jira_ticket PYTHON-898 + @expected_result + + @test_category data_types + """ + def get_upper_bound(seconds): + """ + function that given a truncated date in seconds from the epoch returns that same date + but with the microseconds set to 999999, seconds to 59, minutes to 59, hours to 23 + days 28, 29, 30 or 31 depending on the month and months to 12. + The way to do this is to add one year and leave the date at YEAR-01-01 00:00:00 000000. + Then subtract one millisecond. + """ + dt = datetime.datetime.fromtimestamp(seconds / 1000.0, tz=utc_timezone) + dt = dt + datetime.timedelta(days=370) + dt = dt.replace(day=1) - datetime.timedelta(microseconds=1) + + diff = time.mktime(dt.timetuple()) - time.mktime(self.epoch.timetuple()) + return diff * 1000 + 999 + # This doesn't work for big values because it loses precision + #return int((dt - self.epoch).total_seconds() * 1000) + self._deserialize_date_range({"month": 1, "day": 1, "hour": 0, "minute": 0, "second": 0, "microsecond": 0}, + DateRangePrecision.YEAR, + get_upper_bound, + lambda original_value, i: original_value + i * 900 * 50 * 60 * 24 * 30 * 12 * 7) + + def _deserialize_date_range(self, truncate_kwargs, precision, + round_up_truncated_upper_value, increment_loop_variable): + """ + This functions iterates over several DateRange objects determined by + lower_value upper_value which are given as a value that represents seconds since the epoch. + We want to make sure the lower_value is correctly rounded down and the upper value is correctly rounded up. + In the case of rounding down we verify that the rounded down value + has the appropriate fields set to the minimum they could possibly have. That is + 1 for months, 1 for days, 0 for hours, 0 for minutes, 0 for seconds, 0 for microseconds. + We use the generic function truncate_date which depends on truncate_kwargs for this + + In the case of rounding up we verify that the rounded up value has the appropriate fields set + to the maximum they could possibly have. This is calculated by round_up_truncated_upper_value + which input is the truncated value from before. It is passed as an argument as the way + of calculating this is different for every precision. + + :param truncate_kwargs: determine what values to truncate in truncate_date + :param precision: :class:`~util.DateRangePrecision` + :param round_up_truncated_upper_value: this is a function that gets a truncated date and + returns a new date with some fields set to the maximum possible value + :param increment_loop_variable: this is a function that given a starting value and the iteration + value returns a new date to serve as lower_bound/upper_bound. We need this because the value by which + dates are incremented depends on if the precision is seconds, minutes, hours, days and months + :return: + """ + + def truncate_date(number): + """ + Given a date in seconds since the epoch truncates ups to a certain precision depending on + truncate_kwargs. 
+ The return is the truncated date in seconds since the epoch. + For example if truncate_kwargs = {"hour": 0, "minute": 0, "second": 0, "microsecond": 0} the returned + value will be the original given date but with the hours, minutes, seconds and microseconds set to 0 + """ + dt = datetime.datetime.fromtimestamp(number / 1000.0, tz=utc_timezone) + dt = dt.replace(**truncate_kwargs) + return round((dt - self.epoch).total_seconds() * 1000.0) + + for i in range(1000): + # We increment the lower_value and upper_value according to increment_loop_variable + lower_value = increment_loop_variable(self.starting_lower_value, i) + upper_value = increment_loop_variable(self.starting_upper_value, i) + + # Inside the __init__ for DateRange the rounding up and down should happen + dr = DateRange(DateRangeBound(lower_value, precision), + DateRangeBound(upper_value, precision)) + + # We verify that rounded value corresponds with what we would expect + self.assertEqual(truncate_date(lower_value), dr.lower_bound.milliseconds) + upper_value = round_up_truncated_upper_value(truncate_date(upper_value)) + self.assertEqual(upper_value, dr.upper_bound.milliseconds) + + +class TestOrdering(unittest.TestCase): + def _shuffle_lists(self, *args): + return [item for sublist in zip(*args) for item in sublist] + + def test_host_order(self): + """ + Test Host class is ordered consistently + + @since 3.9 + @jira_ticket PYTHON-714 + @expected_result the hosts are ordered correctly + + @test_category data_types + """ + hosts = [Host(addr, SimpleConvictionPolicy) for addr in + ("127.0.0.1", "127.0.0.2", "127.0.0.3", "127.0.0.4")] + hosts_equal = [Host(addr, SimpleConvictionPolicy) for addr in + ("127.0.0.1", "127.0.0.1")] + hosts_equal_conviction = [Host("127.0.0.1", SimpleConvictionPolicy), Host("127.0.0.1", ConvictionPolicy)] + check_sequence_consistency(self, hosts) + check_sequence_consistency(self, hosts_equal, equal=True) + check_sequence_consistency(self, hosts_equal_conviction, equal=True) + + def test_date_order(self): + """ + Test Date class is ordered consistently + + @since 3.9 + @jira_ticket PYTHON-714 + @expected_result the dates are ordered correctly + + @test_category data_types + """ + dates_from_string = [Date("2017-01-01"), Date("2017-01-05"), Date("2017-01-09"), Date("2017-01-13")] + dates_from_string_equal = [Date("2017-01-01"), Date("2017-01-01")] + check_sequence_consistency(self, dates_from_string) + check_sequence_consistency(self, dates_from_string_equal, equal=True) + + date_format = "%Y-%m-%d" + + dates_from_value = [ + Date((datetime.datetime.strptime(dtstr, date_format) - + datetime.datetime(1970, 1, 1)).days) + for dtstr in ("2017-01-02", "2017-01-06", "2017-01-10", "2017-01-14") + ] + dates_from_value_equal = [Date(1), Date(1)] + check_sequence_consistency(self, dates_from_value) + check_sequence_consistency(self, dates_from_value_equal, equal=True) + + dates_from_datetime = [Date(datetime.datetime.strptime(dtstr, date_format)) + for dtstr in ("2017-01-03", "2017-01-07", "2017-01-11", "2017-01-15")] + dates_from_datetime_equal = [Date(datetime.datetime.strptime("2017-01-01", date_format)), + Date(datetime.datetime.strptime("2017-01-01", date_format))] + check_sequence_consistency(self, dates_from_datetime) + check_sequence_consistency(self, dates_from_datetime_equal, equal=True) + + dates_from_date = [ + Date(datetime.datetime.strptime(dtstr, date_format).date()) for dtstr in + ("2017-01-04", "2017-01-08", "2017-01-12", "2017-01-16") + ] + dates_from_date_equal = 
[datetime.datetime.strptime(dtstr, date_format) for dtstr in + ("2017-01-09", "2017-01-9")] + + check_sequence_consistency(self, dates_from_date) + check_sequence_consistency(self, dates_from_date_equal, equal=True) + + check_sequence_consistency(self, self._shuffle_lists(dates_from_string, dates_from_value, + dates_from_datetime, dates_from_date)) + + def test_timer_order(self): + """ + Test Time class is ordered consistently + + @since 3.9 + @jira_ticket PYTHON-714 + @expected_result the times are ordered correctly + + @test_category data_types + """ + time_from_int = [Time(1000), Time(4000), Time(7000), Time(10000)] + time_from_int_equal = [Time(1), Time(1)] + check_sequence_consistency(self, time_from_int) + check_sequence_consistency(self, time_from_int_equal, equal=True) + + time_from_datetime = [Time(datetime.time(hour=0, minute=0, second=0, microsecond=us)) + for us in (2, 5, 8, 11)] + time_from_datetime_equal = [Time(datetime.time(hour=0, minute=0, second=0, microsecond=us)) + for us in (1, 1)] + check_sequence_consistency(self, time_from_datetime) + check_sequence_consistency(self, time_from_datetime_equal, equal=True) + + time_from_string = [Time("00:00:00.000003000"), Time("00:00:00.000006000"), + Time("00:00:00.000009000"), Time("00:00:00.000012000")] + time_from_string_equal = [Time("00:00:00.000004000"), Time("00:00:00.000004000")] + check_sequence_consistency(self, time_from_string) + check_sequence_consistency(self, time_from_string_equal, equal=True) + + check_sequence_consistency(self, self._shuffle_lists(time_from_int, time_from_datetime, time_from_string)) + + def test_token_order(self): + """ + Test Token class is ordered consistently + + @since 3.9 + @jira_ticket PYTHON-714 + @expected_result the tokens are ordered correctly + + @test_category data_types + """ + tokens = [Token(1), Token(2), Token(3), Token(4)] + tokens_equal = [Token(1), Token(1)] + check_sequence_consistency(self, tokens) + check_sequence_consistency(self, tokens_equal, equal=True) diff --git a/tests/unit/test_util_types.py b/tests/unit/test_util_types.py index 53ab9d0752..779d416923 100644 --- a/tests/unit/test_util_types.py +++ b/tests/unit/test_util_types.py @@ -1,24 +1,23 @@ -# Copyright 2013-2016 DataStax, Inc. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
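`check_sequence_consistency`, added in tests/unit/util.py later in this diff, asserts that every earlier element of an ordered sequence compares strictly less than every later one (and the reverse with `>`), or that all elements compare equal when `equal=True`. A minimal sketch of that property, assuming the same `cassandra.util.Time` values used by `test_timer_order`:

```python
from cassandra.util import Time

times = [Time(1000), Time(4000), Time(7000), Time(10000)]  # nanoseconds since midnight
for i, el in enumerate(times):
    assert all(prev < el for prev in times[:i])            # everything before el sorts lower
    assert all(el < later for later in times[i + 1:])      # everything after el sorts higher
```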
-try: - import unittest2 as unittest -except ImportError: - import unittest # noqa +import unittest import datetime -from cassandra.util import Date, Time +from cassandra.util import Date, Time, Duration, Version class DateTests(unittest.TestCase): @@ -54,7 +53,7 @@ def test_limits(self): max_builtin = Date(datetime.date(9999, 12, 31)) self.assertEqual(Date(min_builtin.days_from_epoch), min_builtin) self.assertEqual(Date(max_builtin.days_from_epoch), max_builtin) - # just proving we can construct with on offset outside buildin range + # just proving we can construct with on offset outside builtin range self.assertEqual(Date(min_builtin.days_from_epoch - 1).days_from_epoch, min_builtin.days_from_epoch - 1) self.assertEqual(Date(max_builtin.days_from_epoch + 1).days_from_epoch, @@ -131,6 +130,11 @@ def test_from_time(self): tt = Time(expected_time) self.assertEqual(tt, expected_time) + def test_as_time(self): + expected_time = datetime.time(12, 1, 2, 3) + tt = Time(expected_time) + self.assertEqual(tt.time(), expected_time) + def test_equals(self): # util.Time self equality self.assertEqual(Time(1234), Time(1234)) @@ -145,3 +149,143 @@ def test_invalid_init(self): self.assertRaises(TypeError, Time, 1.234) self.assertRaises(ValueError, Time, 123456789000000) self.assertRaises(TypeError, Time, datetime.datetime(2004, 12, 23, 11, 11, 1)) + + +class DurationTests(unittest.TestCase): + + def test_valid_format(self): + + valid = Duration(1, 1, 1) + self.assertEqual(valid.months, 1) + self.assertEqual(valid.days, 1) + self.assertEqual(valid.nanoseconds, 1) + + valid = Duration(nanoseconds=100000) + self.assertEqual(valid.months, 0) + self.assertEqual(valid.days, 0) + self.assertEqual(valid.nanoseconds, 100000) + + valid = Duration() + self.assertEqual(valid.months, 0) + self.assertEqual(valid.days, 0) + self.assertEqual(valid.nanoseconds, 0) + + valid = Duration(-10, -21, -1000) + self.assertEqual(valid.months, -10) + self.assertEqual(valid.days, -21) + self.assertEqual(valid.nanoseconds, -1000) + + def test_equality(self): + + first = Duration(1, 1, 1) + second = Duration(-1, 1, 1) + self.assertNotEqual(first, second) + + first = Duration(1, 1, 1) + second = Duration(1, 1, 1) + self.assertEqual(first, second) + + first = Duration() + second = Duration(0, 0, 0) + self.assertEqual(first, second) + + first = Duration(1000, 10000, 2345345) + second = Duration(1000, 10000, 2345345) + self.assertEqual(first, second) + + first = Duration(12, 0, 100) + second = Duration(nanoseconds=100, months=12) + self.assertEqual(first, second) + + def test_str(self): + + self.assertEqual(str(Duration(1, 1, 1)), "1mo1d1ns") + self.assertEqual(str(Duration(1, 1, -1)), "-1mo1d1ns") + self.assertEqual(str(Duration(1, 1, 1000000000000000)), "1mo1d1000000000000000ns") + self.assertEqual(str(Duration(52, 23, 564564)), "52mo23d564564ns") + + +class VersionTests(unittest.TestCase): + + def test_version_parsing(self): + versions = [ + ('2.0.0', (2, 0, 0, 0, 0)), + ('3.1.0', (3, 1, 0, 0, 0)), + ('2.4.54', (2, 4, 54, 0, 0)), + ('3.1.1.12', (3, 1, 1, 12, 0)), + ('3.55.1.build12', (3, 55, 1, 'build12', 0)), + ('3.55.1.20190429-TEST', (3, 55, 1, 20190429, 'TEST')), + ('4.0-SNAPSHOT', (4, 0, 0, 0, 'SNAPSHOT')), + ('1.0.5.4.3', (1, 0, 5, 4, 0)), + ('1-SNAPSHOT', (1, 0, 0, 0, 'SNAPSHOT')), + ('4.0.1.2.3.4.5-ABC-123-SNAP-TEST.blah', (4, 0, 1, 2, 'ABC-123-SNAP-TEST.blah')), + ('2.1.hello', (2, 1, 0, 0, 0)), + ('2.test.1', (2, 0, 0, 0, 0)), + ] + + for str_version, expected_result in versions: + v = Version(str_version) + 
self.assertEqual(str_version, str(v)) + self.assertEqual(v.major, expected_result[0]) + self.assertEqual(v.minor, expected_result[1]) + self.assertEqual(v.patch, expected_result[2]) + self.assertEqual(v.build, expected_result[3]) + self.assertEqual(v.prerelease, expected_result[4]) + + # not supported version formats + with self.assertRaises(ValueError): + Version('test.1.0') + + def test_version_compare(self): + # just tests a bunch of versions + + # major wins + self.assertTrue(Version('3.3.0') > Version('2.5.0')) + self.assertTrue(Version('3.3.0') > Version('2.5.0.66')) + self.assertTrue(Version('3.3.0') > Version('2.5.21')) + + # minor wins + self.assertTrue(Version('2.3.0') > Version('2.2.0')) + self.assertTrue(Version('2.3.0') > Version('2.2.7')) + self.assertTrue(Version('2.3.0') > Version('2.2.7.9')) + + # patch wins + self.assertTrue(Version('2.3.1') > Version('2.3.0')) + self.assertTrue(Version('2.3.1') > Version('2.3.0.4post0')) + self.assertTrue(Version('2.3.1') > Version('2.3.0.44')) + + # various + self.assertTrue(Version('2.3.0.1') > Version('2.3.0.0')) + self.assertTrue(Version('2.3.0.680') > Version('2.3.0.670')) + self.assertTrue(Version('2.3.0.681') > Version('2.3.0.680')) + self.assertTrue(Version('2.3.0.1build0') > Version('2.3.0.1')) # 4th part fallback to str cmp + self.assertTrue(Version('2.3.0.build0') > Version('2.3.0.1')) # 4th part fallback to str cmp + self.assertTrue(Version('2.3.0') < Version('2.3.0.build')) + + self.assertTrue(Version('4-a') <= Version('4.0.0')) + self.assertTrue(Version('4-a') <= Version('4.0-alpha1')) + self.assertTrue(Version('4-a') <= Version('4.0-beta1')) + self.assertTrue(Version('4.0.0') >= Version('4.0.0')) + self.assertTrue(Version('4.0.0.421') >= Version('4.0.0')) + self.assertTrue(Version('4.0.1') >= Version('4.0.0')) + self.assertTrue(Version('2.3.0') == Version('2.3.0')) + self.assertTrue(Version('2.3.32') == Version('2.3.32')) + self.assertTrue(Version('2.3.32') == Version('2.3.32.0')) + self.assertTrue(Version('2.3.0.build') == Version('2.3.0.build')) + + self.assertTrue(Version('4') == Version('4.0.0')) + self.assertTrue(Version('4.0') == Version('4.0.0.0')) + self.assertTrue(Version('4.0') > Version('3.9.3')) + + self.assertTrue(Version('4.0') > Version('4.0-SNAPSHOT')) + self.assertTrue(Version('4.0-SNAPSHOT') == Version('4.0-SNAPSHOT')) + self.assertTrue(Version('4.0.0-SNAPSHOT') == Version('4.0-SNAPSHOT')) + self.assertTrue(Version('4.0.0-SNAPSHOT') == Version('4.0.0-SNAPSHOT')) + self.assertTrue(Version('4.0.0.build5-SNAPSHOT') == Version('4.0.0.build5-SNAPSHOT')) + self.assertTrue(Version('4.1-SNAPSHOT') > Version('4.0-SNAPSHOT')) + self.assertTrue(Version('4.0.1-SNAPSHOT') > Version('4.0.0-SNAPSHOT')) + self.assertTrue(Version('4.0.0.build6-SNAPSHOT') > Version('4.0.0.build5-SNAPSHOT')) + self.assertTrue(Version('4.0-SNAPSHOT2') > Version('4.0-SNAPSHOT1')) + self.assertTrue(Version('4.0-SNAPSHOT2') > Version('4.0.0-SNAPSHOT1')) + + self.assertTrue(Version('4.0.0-alpha1-SNAPSHOT') > Version('4.0.0-SNAPSHOT')) diff --git a/tests/unit/util.py b/tests/unit/util.py new file mode 100644 index 0000000000..e57fa6c3ee --- /dev/null +++ b/tests/unit/util.py @@ -0,0 +1,30 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
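The `Version` assertions above boil down to: parse up to five components (major, minor, patch, build, prerelease), treat missing parts as zero, compare numerically where possible, and fall back to string comparison for the build part. A condensed sketch using rows taken directly from the test data (assumes `cassandra.util.Version` as imported above):

```python
from cassandra.util import Version

v = Version('3.55.1.20190429-TEST')
assert (v.major, v.minor, v.patch, v.build, v.prerelease) == (3, 55, 1, 20190429, 'TEST')

assert Version('4.0') == Version('4.0.0.0')            # missing parts default to zero
assert Version('4.0') > Version('4.0-SNAPSHOT')        # a release sorts after its snapshot
assert Version('2.3.0.build0') > Version('2.3.0.1')    # 4th part falls back to string comparison
```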
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def check_sequence_consistency(unit_test, ordered_sequence, equal=False): + for i, el in enumerate(ordered_sequence): + for previous in ordered_sequence[:i]: + _check_order_consistency(unit_test, previous, el, equal) + for posterior in ordered_sequence[i + 1:]: + _check_order_consistency(unit_test, el, posterior, equal) + + +def _check_order_consistency(unit_test, smaller, bigger, equal=False): + unit_test.assertLessEqual(smaller, bigger) + unit_test.assertGreaterEqual(bigger, smaller) + if equal: + unit_test.assertEqual(smaller, bigger) + else: + unit_test.assertNotEqual(smaller, bigger) + unit_test.assertLess(smaller, bigger) + unit_test.assertGreater(bigger, smaller) diff --git a/tests/unit/utils.py b/tests/unit/utils.py new file mode 100644 index 0000000000..fc3ce4b481 --- /dev/null +++ b/tests/unit/utils.py @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import wraps +from unittest.mock import patch + +from concurrent.futures import Future +from cassandra.cluster import Session + + +def mock_session_pools(f): + """ + Helper decorator that allows tests to initialize :class:.`Session` objects + without actually connecting to a Cassandra cluster. + """ + @wraps(f) + def wrapper(*args, **kwargs): + with patch.object(Session, "add_or_renew_pool") as mocked_add_or_renew_pool: + future = Future() + future.set_result(object()) + mocked_add_or_renew_pool.return_value = future + f(*args, **kwargs) + return wrapper diff --git a/tests/util.py b/tests/util.py new file mode 100644 index 0000000000..d44d6c91c8 --- /dev/null +++ b/tests/util.py @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from functools import wraps + + +def wait_until(condition, delay, max_attempts): + """ + Executes a function at regular intervals while the condition + is false and the amount of attempts < maxAttempts. + :param condition: a function + :param delay: the delay in second + :param max_attempts: the maximum number of attempts. So the timeout + of this function is delay*max_attempts + """ + attempt = 0 + while not condition() and attempt < max_attempts: + attempt += 1 + time.sleep(delay) + + if attempt >= max_attempts: + raise Exception("Condition is still False after {} attempts.".format(max_attempts)) + + +def wait_until_not_raised(condition, delay, max_attempts): + """ + Executes a function at regular intervals while the condition + doesn't raise an exception and the amount of attempts < maxAttempts. + :param condition: a function + :param delay: the delay in second + :param max_attempts: the maximum number of attempts. So the timeout + of this function will be delay*max_attempts + """ + def wrapped_condition(): + try: + result = condition() + except: + return False, None + + return True, result + + attempt = 0 + while attempt < (max_attempts-1): + attempt += 1 + success, result = wrapped_condition() + if success: + return result + + time.sleep(delay) + + # last attempt, let the exception raise + return condition() + + +def late(seconds=1): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + time.sleep(seconds) + func(*args, **kwargs) + return wrapper + return decorator diff --git a/tox.ini b/tox.ini index f4d21ba5df..e77835f0da 100644 --- a/tox.ini +++ b/tox.ini @@ -1,19 +1,46 @@ [tox] -envlist = py{26,27,33,34},pypy,pypy3 +envlist = py{39,310,311,312,313},pypy [base] -deps = nose - mock<=1.0.1 - PyYAML - six +deps = pytest + packaging + cython>=3.0 + eventlet + gevent + twisted[tls] + pure-sasl + kerberos + futurist + lz4 + cryptography>=42.0 [testenv] deps = {[base]deps} - cython - py26: unittest2 - py{26,27}: gevent - twisted <15.5.0 + +setenv = LIBEV_EMBED=0 + CARES_EMBED=0 + LC_ALL=en_US.UTF-8 +changedir = {envtmpdir} +commands = pytest -v {toxinidir}/tests/unit/ + + +[testenv:gevent_loop] +deps = {[base]deps} + +setenv = LIBEV_EMBED=0 + CARES_EMBED=0 + EVENT_LOOP_MANAGER=gevent +changedir = {envtmpdir} +commands = + pytest -v {toxinidir}/tests/unit/io/test_geventreactor.py + + +[testenv:eventlet_loop] +deps = {[base]deps} + setenv = LIBEV_EMBED=0 CARES_EMBED=0 + EVENT_LOOP_MANAGER=eventlet changedir = {envtmpdir} -commands = nosetests --verbosity=2 --no-path-adjustment {toxinidir}/tests/unit/ +commands = + pytest -v {toxinidir}/tests/unit/io/test_eventletreactor.py
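The new tests/util.py helpers poll a condition (or a callable that may transiently raise) at a fixed interval for a bounded number of attempts. A hedged usage sketch, assuming the `tests` package is importable from the repository root; the `flaky` callable is made up purely for illustration:

```python
import time
from tests.util import wait_until, wait_until_not_raised

# Poll a boolean condition every 0.2s, giving up after 10 attempts (~2s total).
start = time.time()
wait_until(lambda: time.time() - start > 0.5, 0.2, 10)

# Retry a callable until it stops raising, returning its result;
# the final attempt re-raises if it still fails.
calls = {'n': 0}
def flaky():
    calls['n'] += 1
    if calls['n'] < 3:
        raise RuntimeError("not ready yet")
    return "ready"

assert wait_until_not_raised(flaky, 0.1, 10) == "ready"
```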