From b6ce86467a55213ec509a3a1bf3432597325a885 Mon Sep 17 00:00:00 2001
From: Sergei Isakov <iserge@google.com>
Date: Mon, 19 Dec 2022 16:38:43 +0100
Subject: [PATCH] Fix array overflow for small input vectors.

---
 lib/vectorspace.h                |  6 +++---
 lib/vectorspace_cuda.h           |  6 +++---
 pybind_interface/pybind_main.cpp | 10 +++++++++-
 3 files changed, 15 insertions(+), 7 deletions(-)

diff --git a/lib/vectorspace.h b/lib/vectorspace.h
index 246394e1..3b475460 100644
--- a/lib/vectorspace.h
+++ b/lib/vectorspace.h
@@ -102,7 +102,7 @@ class VectorSpace {
   }
 
   // It is the client's responsibility to make sure that p has at least
-  // 2 * 2^num_qubits elements.
+  // Impl::MinSize(num_qubits) elements.
   static Vector Create(fp_type* p, unsigned num_qubits) {
     return Vector{Pointer{p, &detail::do_not_free}, num_qubits};
   }
@@ -135,7 +135,7 @@ class VectorSpace {
   }
 
   // It is the client's responsibility to make sure that dest has at least
-  // 2 * 2^src.num_qubits() elements.
+  // Impl::MinSize(src.num_qubits()) elements.
   bool Copy(const Vector& src, fp_type* dest) const {
     auto f = [](unsigned n, unsigned m, uint64_t i,
                 const fp_type* src, fp_type* dest) {
@@ -148,7 +148,7 @@ class VectorSpace {
   }
 
   // It is the client's responsibility to make sure that src has at least
-  // 2 * 2^dest.num_qubits() elements.
+  // Impl::MinSize(dest.num_qubits()) elements.
   bool Copy(const fp_type* src, Vector& dest) const {
     auto f = [](unsigned n, unsigned m, uint64_t i,
                 const fp_type* src, fp_type* dest) {
diff --git a/lib/vectorspace_cuda.h b/lib/vectorspace_cuda.h
index d26f003f..27d29c23 100644
--- a/lib/vectorspace_cuda.h
+++ b/lib/vectorspace_cuda.h
@@ -92,7 +92,7 @@ class VectorSpaceCUDA {
   }
 
   // It is the client's responsibility to make sure that p has at least
-  // 2 * 2^num_qubits elements.
+  // Impl::MinSize(num_qubits) elements.
   static Vector Create(fp_type* p, unsigned num_qubits) {
     return Vector{Pointer{p, &detail::do_not_free}, num_qubits};
   }
@@ -122,7 +122,7 @@ class VectorSpaceCUDA {
   }
 
   // It is the client's responsibility to make sure that dest has at least
-  // 2 * 2^src.num_qubits() elements.
+  // Impl::MinSize(src.num_qubits()) elements.
   bool Copy(const Vector& src, fp_type* dest) const {
     cudaMemcpy(dest, src.get(),
                sizeof(fp_type) * Impl::MinSize(src.num_qubits()),
@@ -132,7 +132,7 @@ class VectorSpaceCUDA {
   }
 
   // It is the client's responsibility to make sure that src has at least
-  // 2 * 2^dest.num_qubits() elements.
+  // Impl::MinSize(dest.num_qubits()) elements.
   bool Copy(const fp_type* src, Vector& dest) const {
     cudaMemcpy(dest.get(), src,
                sizeof(fp_type) * Impl::MinSize(dest.num_qubits()),
diff --git a/pybind_interface/pybind_main.cpp b/pybind_interface/pybind_main.cpp
index 94d7cefa..b4fe4859 100644
--- a/pybind_interface/pybind_main.cpp
+++ b/pybind_interface/pybind_main.cpp
@@ -691,7 +691,15 @@ class SimulatorHelper {
 
   void init_state(const py::array_t<float> &input_vector) {
     StateSpace state_space = factory.CreateStateSpace();
-    state_space.Copy(input_vector.data(), state);
+    if (state.num_qubits() >= 5) {
+      state_space.Copy(input_vector.data(), state);
+    } else {
+      state_space.SetAllZeros(state);
+      uint64_t size = 2 * (uint64_t{1} << state.num_qubits());
+      for (uint64_t i = 0; i < size; ++i) {
+        state.get()[i] = input_vector.data()[i];
+      }
+    }
     state_space.NormalToInternalOrder(state);
   }