From 38052885bf58c6f32c4bbd0142a221564d53c3ff Mon Sep 17 00:00:00 2001
From: Robert Pang <30942926+robertpang@users.noreply.github.com>
Date: Sun, 20 May 2018 17:14:58 -0700
Subject: [PATCH] [cassandra] Update CassandraCQLClient to use
 PreparedStatement for better performance (#1051)

* Optimize PreparedStatement lookup by looking up by field set and avoiding building the query string unless the statement has not been prepared.
* Add tests for update, delete and prepared statements.
* fix logger calls
* Credit to @haaawk for incorporating some of the feedback
---
 .../com/yahoo/ycsb/db/CassandraCQLClient.java | 335 ++++++++++++------
 .../yahoo/ycsb/db/CassandraCQLClientTest.java |  62 +++-
 2 files changed, 294 insertions(+), 103 deletions(-)

diff --git a/cassandra/src/main/java/com/yahoo/ycsb/db/CassandraCQLClient.java b/cassandra/src/main/java/com/yahoo/ycsb/db/CassandraCQLClient.java
index aefd7798..f83b8b40 100644
--- a/cassandra/src/main/java/com/yahoo/ycsb/db/CassandraCQLClient.java
+++ b/cassandra/src/main/java/com/yahoo/ycsb/db/CassandraCQLClient.java
@@ -26,11 +26,12 @@ import com.datastax.driver.core.Metadata;
 import com.datastax.driver.core.ResultSet;
 import com.datastax.driver.core.Row;
 import com.datastax.driver.core.Session;
-import com.datastax.driver.core.SimpleStatement;
-import com.datastax.driver.core.Statement;
+import com.datastax.driver.core.PreparedStatement;
+import com.datastax.driver.core.BoundStatement;
 import com.datastax.driver.core.querybuilder.Insert;
 import com.datastax.driver.core.querybuilder.QueryBuilder;
 import com.datastax.driver.core.querybuilder.Select;
+import com.datastax.driver.core.querybuilder.Update;
 import com.yahoo.ycsb.ByteArrayByteIterator;
 import com.yahoo.ycsb.ByteIterator;
 import com.yahoo.ycsb.DB;
@@ -39,10 +40,18 @@ import com.yahoo.ycsb.Status;
 
 import java.nio.ByteBuffer;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.Map;
 import java.util.Set;
 import java.util.Vector;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.slf4j.helpers.MessageFormatter;
 
 /**
  * Cassandra 2.x CQL client.
@@ -53,9 +62,26 @@ import java.util.concurrent.atomic.AtomicInteger;
  */
 public class CassandraCQLClient extends DB {
 
+  private static Logger logger = LoggerFactory.getLogger(CassandraCQLClient.class);
+
   private static Cluster cluster = null;
   private static Session session = null;
 
+  private static ConcurrentMap<Set<String>, PreparedStatement> readStmts =
+      new ConcurrentHashMap<Set<String>, PreparedStatement>();
+  private static ConcurrentMap<Set<String>, PreparedStatement> scanStmts =
+      new ConcurrentHashMap<Set<String>, PreparedStatement>();
+  private static ConcurrentMap<Set<String>, PreparedStatement> insertStmts =
+      new ConcurrentHashMap<Set<String>, PreparedStatement>();
+  private static ConcurrentMap<Set<String>, PreparedStatement> updateStmts =
+      new ConcurrentHashMap<Set<String>, PreparedStatement>();
+  private static AtomicReference<PreparedStatement> readAllStmt =
+      new AtomicReference<PreparedStatement>();
+  private static AtomicReference<PreparedStatement> scanAllStmt =
+      new AtomicReference<PreparedStatement>();
+  private static AtomicReference<PreparedStatement> deleteStmt =
+      new AtomicReference<PreparedStatement>();
+
   private static ConsistencyLevel readConsistencyLevel = ConsistencyLevel.ONE;
   private static ConsistencyLevel writeConsistencyLevel = ConsistencyLevel.ONE;
 
@@ -184,11 +210,11 @@ public class CassandraCQLClient extends DB {
         }
 
         Metadata metadata = cluster.getMetadata();
-        System.err.printf("Connected to cluster: %s\n",
+        logger.info("Connected to cluster: {}\n",
             metadata.getClusterName());
 
         for (Host discoveredHost : metadata.getAllHosts()) {
-          System.out.printf("Datacenter: %s; Host: %s; Rack: %s\n",
+          logger.info("Datacenter: {}; Host: {}; Rack: {}\n",
               discoveredHost.getDatacenter(), discoveredHost.getAddress(),
               discoveredHost.getRack());
         }
@@ -210,6 +236,13 @@ public class CassandraCQLClient extends DB {
     synchronized (INIT_COUNT) {
       final int curInitCount = INIT_COUNT.decrementAndGet();
       if (curInitCount <= 0) {
+        readStmts.clear();
+        scanStmts.clear();
+        insertStmts.clear();
+        updateStmts.clear();
+        readAllStmt.set(null);
+        scanAllStmt.set(null);
+        deleteStmt.set(null);
         session.close();
         cluster.close();
         cluster = null;
@@ -241,30 +274,41 @@ public class CassandraCQLClient extends DB {
   public Status read(String table, String key, Set<String> fields,
       Map<String, ByteIterator> result) {
     try {
-      Statement stmt;
-      Select.Builder selectBuilder;
-
-      if (fields == null) {
-        selectBuilder = QueryBuilder.select().all();
-      } else {
-        selectBuilder = QueryBuilder.select();
-        for (String col : fields) {
-          ((Select.Selection) selectBuilder).column(col);
+      PreparedStatement stmt = (fields == null) ? readAllStmt.get() : readStmts.get(fields);
+
+      // Prepare statement on demand
+      if (stmt == null) {
+        Select.Builder selectBuilder;
+
+        if (fields == null) {
+          selectBuilder = QueryBuilder.select().all();
+        } else {
+          selectBuilder = QueryBuilder.select();
+          for (String col : fields) {
+            ((Select.Selection) selectBuilder).column(col);
+          }
         }
-      }
 
-      stmt = selectBuilder.from(table).where(QueryBuilder.eq(YCSB_KEY, key))
-          .limit(1);
-      stmt.setConsistencyLevel(readConsistencyLevel);
+        stmt = session.prepare(selectBuilder.from(table)
+                               .where(QueryBuilder.eq(YCSB_KEY, QueryBuilder.bindMarker()))
+                               .limit(1));
+        stmt.setConsistencyLevel(readConsistencyLevel);
+        if (trace) {
+          stmt.enableTracing();
+        }
 
-      if (debug) {
-        System.out.println(stmt.toString());
-      }
-      if (trace) {
-        stmt.enableTracing();
+        PreparedStatement prevStmt = (fields == null) ?
+                                     readAllStmt.getAndSet(stmt) :
+                                     readStmts.putIfAbsent(new HashSet(fields), stmt);
+        if (prevStmt != null) {
+          stmt = prevStmt;
+        }
       }
-      
-      ResultSet rs = session.execute(stmt);
+
+      logger.debug(stmt.getQueryString());
+      logger.debug("key = {}", key);
+
+      ResultSet rs = session.execute(stmt.bind(key));
 
       if (rs.isExhausted()) {
         return Status.NOT_FOUND;
@@ -286,8 +330,7 @@ public class CassandraCQLClient extends DB {
       return Status.OK;
 
     } catch (Exception e) {
-      e.printStackTrace();
-      System.out.println("Error reading key: " + key);
+      logger.error(MessageFormatter.format("Error reading key: {}", key).getMessage(), e);
       return Status.ERROR;
     }
 
@@ -318,45 +361,55 @@ public class CassandraCQLClient extends DB {
       Set<String> fields, Vector<HashMap<String, ByteIterator>> result) {
 
     try {
-      Statement stmt;
-      Select.Builder selectBuilder;
-
-      if (fields == null) {
-        selectBuilder = QueryBuilder.select().all();
-      } else {
-        selectBuilder = QueryBuilder.select();
-        for (String col : fields) {
-          ((Select.Selection) selectBuilder).column(col);
+      PreparedStatement stmt = (fields == null) ? scanAllStmt.get() : scanStmts.get(fields);
+
+      // Prepare statement on demand
+      if (stmt == null) {
+        Select.Builder selectBuilder;
+
+        if (fields == null) {
+          selectBuilder = QueryBuilder.select().all();
+        } else {
+          selectBuilder = QueryBuilder.select();
+          for (String col : fields) {
+            ((Select.Selection) selectBuilder).column(col);
+          }
         }
-      }
 
-      stmt = selectBuilder.from(table);
-
-      // The statement builder is not setup right for tokens.
-      // So, we need to build it manually.
-      String initialStmt = stmt.toString();
-      StringBuilder scanStmt = new StringBuilder();
-      scanStmt.append(initialStmt.substring(0, initialStmt.length() - 1));
-      scanStmt.append(" WHERE ");
-      scanStmt.append(QueryBuilder.token(YCSB_KEY));
-      scanStmt.append(" >= ");
-      scanStmt.append("token('");
-      scanStmt.append(startkey);
-      scanStmt.append("')");
-      scanStmt.append(" LIMIT ");
-      scanStmt.append(recordcount);
-
-      stmt = new SimpleStatement(scanStmt.toString());
-      stmt.setConsistencyLevel(readConsistencyLevel);
-
-      if (debug) {
-        System.out.println(stmt.toString());
-      }
-      if (trace) {
-        stmt.enableTracing();
+        Select selectStmt = selectBuilder.from(table);
+
+        // The statement builder is not setup right for tokens.
+        // So, we need to build it manually.
+        String initialStmt = selectStmt.toString();
+        StringBuilder scanStmt = new StringBuilder();
+        scanStmt.append(initialStmt.substring(0, initialStmt.length() - 1));
+        scanStmt.append(" WHERE ");
+        scanStmt.append(QueryBuilder.token(YCSB_KEY));
+        scanStmt.append(" >= ");
+        scanStmt.append("token(");
+        scanStmt.append(QueryBuilder.bindMarker());
+        scanStmt.append(")");
+        scanStmt.append(" LIMIT ");
+        scanStmt.append(QueryBuilder.bindMarker());
+
+        stmt = session.prepare(scanStmt.toString());
+        stmt.setConsistencyLevel(readConsistencyLevel);
+        if (trace) {
+          stmt.enableTracing();
+        }
+
+        PreparedStatement prevStmt = (fields == null) ?
+                                     scanAllStmt.getAndSet(stmt) :
+                                     scanStmts.putIfAbsent(new HashSet(fields), stmt);
+        if (prevStmt != null) {
+          stmt = prevStmt;
+        }
       }
-      
-      ResultSet rs = session.execute(stmt);
+
+      logger.debug(stmt.getQueryString());
+      logger.debug("startKey = {}, recordcount = {}", startkey, recordcount);
+
+      ResultSet rs = session.execute(stmt.bind(startkey, Integer.valueOf(recordcount)));
 
       HashMap<String, ByteIterator> tuple;
       while (!rs.isExhausted()) {
@@ -380,8 +433,8 @@ public class CassandraCQLClient extends DB {
       return Status.OK;
 
     } catch (Exception e) {
-      e.printStackTrace();
-      System.out.println("Error scanning with startkey: " + startkey);
+      logger.error(
+          MessageFormatter.format("Error scanning with startkey: {}", startkey).getMessage(), e);
       return Status.ERROR;
     }
 
@@ -401,10 +454,62 @@ public class CassandraCQLClient extends DB {
    * @return Zero on success, a non-zero error code on error
    */
   @Override
-  public Status update(String table, String key,
-                       Map<String, ByteIterator> values) {
-    // Insert and updates provide the same functionality
-    return insert(table, key, values);
+  public Status update(String table, String key, Map<String, ByteIterator> values) {
+
+    try {
+      Set<String> fields = values.keySet();
+      PreparedStatement stmt = updateStmts.get(fields);
+
+      // Prepare statement on demand
+      if (stmt == null) {
+        Update updateStmt = QueryBuilder.update(table);
+
+        // Add fields
+        for (String field : fields) {
+          updateStmt.with(QueryBuilder.set(field, QueryBuilder.bindMarker()));
+        }
+
+        // Add key
+        updateStmt.where(QueryBuilder.eq(YCSB_KEY, QueryBuilder.bindMarker()));
+
+        stmt = session.prepare(updateStmt);
+        stmt.setConsistencyLevel(writeConsistencyLevel);
+        if (trace) {
+          stmt.enableTracing();
+        }
+
+        PreparedStatement prevStmt = updateStmts.putIfAbsent(new HashSet(fields), stmt);
+        if (prevStmt != null) {
+          stmt = prevStmt;
+        }
+      }
+
+      if (logger.isDebugEnabled()) {
+        logger.debug(stmt.getQueryString());
+        logger.debug("key = {}", key);
+        for (Map.Entry<String, ByteIterator> entry : values.entrySet()) {
+          logger.debug("{} = {}", entry.getKey(), entry.getValue());
+        }
+      }
+
+      // Add fields
+      ColumnDefinitions vars = stmt.getVariables();
+      BoundStatement boundStmt = stmt.bind();
+      for (int i = 0; i < vars.size() - 1; i++) {
+        boundStmt.setString(i, values.get(vars.getName(i)).toString());
+      }
+
+      // Add key
+      boundStmt.setString(vars.size() - 1, key);
+
+      session.execute(boundStmt);
+
+      return Status.OK;
+    } catch (Exception e) {
+      logger.error(MessageFormatter.format("Error updating key: {}", key).getMessage(), e);
+    }
+
+    return Status.ERROR;
   }
 
   /**
@@ -421,38 +526,58 @@ public class CassandraCQLClient extends DB {
    * @return Zero on success, a non-zero error code on error
    */
   @Override
-  public Status insert(String table, String key,
-      Map<String, ByteIterator> values) {
+  public Status insert(String table, String key, Map<String, ByteIterator> values) {
 
     try {
-      Insert insertStmt = QueryBuilder.insertInto(table);
+      Set<String> fields = values.keySet();
+      PreparedStatement stmt = insertStmts.get(fields);
 
-      // Add key
-      insertStmt.value(YCSB_KEY, key);
+      // Prepare statement on demand
+      if (stmt == null) {
+        Insert insertStmt = QueryBuilder.insertInto(table);
 
-      // Add fields
-      for (Map.Entry<String, ByteIterator> entry : values.entrySet()) {
-        Object value;
-        ByteIterator byteIterator = entry.getValue();
-        value = byteIterator.toString();
+        // Add key
+        insertStmt.value(YCSB_KEY, QueryBuilder.bindMarker());
 
-        insertStmt.value(entry.getKey(), value);
-      }
+        // Add fields
+        for (String field : fields) {
+          insertStmt.value(field, QueryBuilder.bindMarker());
+        }
 
-      insertStmt.setConsistencyLevel(writeConsistencyLevel);
+        stmt = session.prepare(insertStmt);
+        stmt.setConsistencyLevel(writeConsistencyLevel);
+        if (trace) {
+          stmt.enableTracing();
+        }
 
-      if (debug) {
-        System.out.println(insertStmt.toString());
+        PreparedStatement prevStmt = insertStmts.putIfAbsent(new HashSet(fields), stmt);
+        if (prevStmt != null) {
+          stmt = prevStmt;
+        }
       }
-      if (trace) {
-        insertStmt.enableTracing();
+
+      if (logger.isDebugEnabled()) {
+        logger.debug(stmt.getQueryString());
+        logger.debug("key = {}", key);
+        for (Map.Entry<String, ByteIterator> entry : values.entrySet()) {
+          logger.debug("{} = {}", entry.getKey(), entry.getValue());
+        }
       }
-      
-      session.execute(insertStmt);
+
+      // Add key
+      BoundStatement boundStmt = stmt.bind().setString(0, key);
+
+      // Add fields
+      ColumnDefinitions vars = stmt.getVariables();
+      for (int i = 1; i < vars.size(); i++) {
+        boundStmt.setString(i, values.get(vars.getName(i)).toString());
+      }
+
+      session.execute(boundStmt);
 
       return Status.OK;
     } catch (Exception e) {
-      e.printStackTrace();
+      logger.error(MessageFormatter.format("Error inserting key: {}", key).getMessage(), e);
     }
 
     return Status.ERROR;
@@ -471,25 +596,31 @@ public class CassandraCQLClient extends DB {
   public Status delete(String table, String key) {
 
     try {
-      Statement stmt;
-
-      stmt = QueryBuilder.delete().from(table)
-          .where(QueryBuilder.eq(YCSB_KEY, key));
-      stmt.setConsistencyLevel(writeConsistencyLevel);
+      PreparedStatement stmt = deleteStmt.get();
+
+      // Prepare statement on demand
+      if (stmt == null) {
+        stmt = session.prepare(QueryBuilder.delete().from(table)
+                               .where(QueryBuilder.eq(YCSB_KEY, QueryBuilder.bindMarker())));
+        stmt.setConsistencyLevel(writeConsistencyLevel);
+        if (trace) {
+          stmt.enableTracing();
+        }
 
-      if (debug) {
-        System.out.println(stmt.toString());
-      }
-      if (trace) {
-        stmt.enableTracing();
+        PreparedStatement prevStmt = deleteStmt.getAndSet(stmt);
+        if (prevStmt != null) {
+          stmt = prevStmt;
+        }
       }
-      
-      session.execute(stmt);
+
+      logger.debug(stmt.getQueryString());
+      logger.debug("key = {}", key);
+
+      session.execute(stmt.bind(key));
 
       return Status.OK;
     } catch (Exception e) {
-      e.printStackTrace();
-      System.out.println("Error deleting key: " + key);
+      logger.error(MessageFormatter.format("Error deleting key: {}", key).getMessage(), e);
     }
 
     return Status.ERROR;
diff --git a/cassandra/src/test/java/com/yahoo/ycsb/db/CassandraCQLClientTest.java b/cassandra/src/test/java/com/yahoo/ycsb/db/CassandraCQLClientTest.java
index 9c136666..f339da2c 100644
--- a/cassandra/src/test/java/com/yahoo/ycsb/db/CassandraCQLClientTest.java
+++ b/cassandra/src/test/java/com/yahoo/ycsb/db/CassandraCQLClientTest.java
@@ -22,6 +22,7 @@ import static org.hamcrest.Matchers.hasEntry;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.notNullValue;
+import static org.hamcrest.Matchers.nullValue;
 
 import com.google.common.collect.Sets;
 
@@ -155,7 +156,7 @@ public class CassandraCQLClientTest {
   }
 
   @Test
-  public void testUpdate() throws Exception {
+  public void testInsert() throws Exception {
     final String key = "key";
     final Map<String, String> input = new HashMap<String, String>();
     input.put("field0", "value1");
@@ -178,4 +179,63 @@ public class CassandraCQLClientTest {
     assertThat(row.getString("field0"), is("value1"));
     assertThat(row.getString("field1"), is("value2"));
   }
+
+  @Test
+  public void testUpdate() throws Exception {
+    insertRow();
+    final Map<String, String> input = new HashMap<String, String>();
+    input.put("field0", "new-value1");
+    input.put("field1", "new-value2");
+
+    final Status status = client.update(TABLE,
+                                        DEFAULT_ROW_KEY,
+                                        StringByteIterator.getByteIteratorMap(input));
+    assertThat(status, is(Status.OK));
+
+    // Verify result
+    final Select selectStmt =
+        QueryBuilder.select("field0", "field1")
+            .from(TABLE)
+            .where(QueryBuilder.eq(CassandraCQLClient.YCSB_KEY, DEFAULT_ROW_KEY))
+            .limit(1);
+
+    final ResultSet rs = session.execute(selectStmt);
+    final Row row = rs.one();
+    assertThat(row, notNullValue());
+    assertThat(rs.isExhausted(), is(true));
+    assertThat(row.getString("field0"), is("new-value1"));
+    assertThat(row.getString("field1"), is("new-value2"));
+  }
+
+  @Test
+  public void testDelete() throws Exception {
+    insertRow();
+
+    final Status status = client.delete(TABLE, DEFAULT_ROW_KEY);
+    assertThat(status, is(Status.OK));
+
+    // Verify result
+    final Select selectStmt =
+        QueryBuilder.select("field0", "field1")
+            .from(TABLE)
+            .where(QueryBuilder.eq(CassandraCQLClient.YCSB_KEY, DEFAULT_ROW_KEY))
+            .limit(1);
+
+    final ResultSet rs = session.execute(selectStmt);
+    final Row row = rs.one();
+    assertThat(row, nullValue());
+  }
+
+  @Test
+  public void testPreparedStatements() throws Exception {
+    final int LOOP_COUNT = 3;
+    for (int i = 0; i < LOOP_COUNT; i++) {
+      testInsert();
+      testUpdate();
+      testRead();
+      testReadSingleColumn();
+      testReadMissingRow();
+      testDelete();
+    }
+  }
 }
-- 
GitLab