diff --git a/include/matx/transforms/fft.h b/include/matx/transforms/fft.h
index 292369ba0..4711a4c30 100644
--- a/include/matx/transforms/fft.h
+++ b/include/matx/transforms/fft.h
@@ -172,6 +172,9 @@ template <typename OutTensorType, typename InTensorType> class matxFFTPlan_t {
     MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
     FftParams_t params;
 
+    // Default to the null (default) stream; callers will generally overwrite this
+    params.stream = 0;
+
     params.irank = i.Rank();
     params.orank = o.Rank();
 
@@ -429,9 +432,11 @@ class matxFFTPlan1D_t : public matxFFTPlan_t<OutTensorType, InTensorType> {
  *   Output view
  * @param i
  *   Input view
+ * @param stream
+ *   CUDA stream on which the plan's workspace memory is allocated
  *
  * */
-matxFFTPlan1D_t(OutTensorType &o, const InTensorType &i)
+matxFFTPlan1D_t(OutTensorType &o, const InTensorType &i, cudaStream_t stream = 0)
 {
   MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
   
@@ -468,8 +473,7 @@ matxFFTPlan1D_t(OutTensorType &o, const InTensorType &i)
                       &workspaceSize, this->params_.exec_type);
   MATX_ASSERT(error == CUFFT_SUCCESS, matxCufftError);
 
-  matxAlloc((void **)&this->workspace_, workspaceSize);
-  cudaMemPrefetchAsync(this->workspace_, workspaceSize, dev, 0);
+  matxAlloc((void **)&this->workspace_, workspaceSize, MATX_ASYNC_DEVICE_MEMORY, stream);
   cufftSetWorkArea(this->plan_, this->workspace_);
 
   error = cufftXtMakePlanMany(
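The allocation change in the hunk above replaces a managed-memory allocation plus an explicit prefetch on the default stream with a single allocation ordered on the caller's stream. A minimal standalone CUDA sketch of the two patterns, under the assumption that matxAlloc with MATX_ASYNC_DEVICE_MEMORY maps onto CUDA's stream-ordered allocator (cudaMallocAsync); this is an illustration, not MatX code:

```cpp
#include <cuda_runtime.h>

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  void *workspace = nullptr;
  size_t workspaceSize = 1 << 20;  // placeholder size

  // Old pattern (removed by the patch): managed allocation, then a prefetch
  // queued on the default stream regardless of where the FFT will run.
  // cudaMallocManaged(&workspace, workspaceSize);
  // cudaMemPrefetchAsync(workspace, workspaceSize, /*dstDevice=*/0, /*stream=*/0);

  // New pattern: the allocation is ordered on the caller's stream, so it
  // neither touches nor synchronizes with the default stream.
  cudaMallocAsync(&workspace, workspaceSize, stream);

  // ... launch work that uses `workspace` on `stream` ...

  cudaFreeAsync(workspace, stream);
  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
  return 0;
}
```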
@@ -531,6 +535,8 @@ virtual void inline Exec(OutTensorType &o, const InTensorType &i,
  *   Output view data type
  * @tparam T2
  *   Input view data type
+ * @param stream
+ *   CUDA stream, passed to the constructor, on which the plan's workspace memory is allocated
  */
 template <typename OutTensorType, typename InTensorType = OutTensorType>
 class matxFFTPlan2D_t : public matxFFTPlan_t<OutTensorType, InTensorType> {
@@ -548,7 +554,7 @@ class matxFFTPlan2D_t : public matxFFTPlan_t<OutTensorType, InTensorType> {
    *   Input view
    *
    * */
-  matxFFTPlan2D_t(OutTensorType &o, const InTensorType &i)
+  matxFFTPlan2D_t(OutTensorType &o, const InTensorType &i, cudaStream_t stream = 0)
   {
     static_assert(RANK >= 2, "2D FFTs require a rank-2 tensor or higher");
     
@@ -595,8 +601,7 @@ class matxFFTPlan2D_t : public matxFFTPlan_t<OutTensorType, InTensorType> {
                        this->params_.output_type, this->params_.batch,
                        &workspaceSize, this->params_.exec_type);
 
-    matxAlloc((void **)&this->workspace_, workspaceSize);
-    cudaMemPrefetchAsync(this->workspace_, workspaceSize, dev, 0);
+    matxAlloc((void **)&this->workspace_, workspaceSize, MATX_ASYNC_DEVICE_MEMORY, stream);
     cufftSetWorkArea(this->plan_, this->workspace_);
 
     error = cufftXtMakePlanMany(
@@ -892,7 +897,7 @@ __MATX_INLINE__ void fft_impl(OutputTensor o, const InputTensor i,
   // Get cache or new FFT plan if it doesn't exist
   auto ret = detail::cache_1d.Lookup(params);
   if (ret == std::nullopt) {
-    auto tmp = new detail::matxFFTPlan1D_t<decltype(out), decltype(in)>{out, in};
+    auto tmp = new detail::matxFFTPlan1D_t<decltype(out), decltype(in)>{out, in, stream};
     detail::cache_1d.Insert(params, static_cast<void *>(tmp));
     tmp->Forward(out, in, stream, norm);
   }
@@ -935,7 +940,7 @@ __MATX_INLINE__ void ifft_impl(OutputTensor o, const InputTensor i,
   // Get cache or new FFT plan if it doesn't exist
   auto ret = detail::cache_1d.Lookup(params);
   if (ret == std::nullopt) {
-    auto tmp = new detail::matxFFTPlan1D_t<decltype(out), decltype(in)>{out, in};
+    auto tmp = new detail::matxFFTPlan1D_t<decltype(out), decltype(in)>{out, in, stream};
     detail::cache_1d.Insert(params, static_cast<void *>(tmp));
     tmp->Inverse(out, in, stream, norm);
   }
@@ -991,7 +996,7 @@ __MATX_INLINE__ void fft2_impl(OutputTensor o, const InputTensor i,
   // Get cache or new FFT plan if it doesn't exist
   auto ret = detail::cache_2d.Lookup(params);
   if (ret == std::nullopt) {
-    auto tmp = new detail::matxFFTPlan2D_t<decltype(out), decltype(in)>{out, in};
+    auto tmp = new detail::matxFFTPlan2D_t<decltype(out), decltype(in)>{out, in, stream};
     detail::cache_2d.Insert(params, static_cast<void *>(tmp));
     tmp->Forward(out, in, stream);
   }
@@ -1047,7 +1052,7 @@ __MATX_INLINE__ void ifft2_impl(OutputTensor o, const InputTensor i,
   // Get cache or new FFT plan if it doesn't exist
   auto ret = detail::cache_2d.Lookup(params);
   if (ret == std::nullopt) {
-    auto tmp = new detail::matxFFTPlan2D_t<decltype(in), decltype(out)>{out, in};
+    auto tmp = new detail::matxFFTPlan2D_t<decltype(out), decltype(in)>{out, in, stream};
     detail::cache_2d.Insert(params, static_cast<void *>(tmp));
     tmp->Inverse(out, in, stream);
   }
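Downstream, the effect of threading the stream through the *_impl functions is that first-use plan creation allocates its cuFFT workspace on the same stream the transform later runs on. A hedged usage sketch of how a caller might exercise this path; it assumes MatX's operator-style API ((B = fft(A)).run(stream)) and the make_tensor helper, so treat exact names and signatures as assumptions rather than guarantees:

```cpp
#include <matx.h>
#include <cuda/std/complex>

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // make_tensor defaults to managed memory, so host-side initialization is fine.
  auto A = matx::make_tensor<cuda::std::complex<float>>({1024});
  auto B = matx::make_tensor<cuda::std::complex<float>>({1024});
  for (int i = 0; i < 1024; i++) {
    A(i) = {static_cast<float>(i), 0.f};
  }

  // First call for this shape/type: fft_impl misses the plan cache, constructs a
  // matxFFTPlan1D_t with the caller's stream, and the workspace allocation is
  // therefore ordered on `stream` rather than the default stream.
  (B = matx::fft(A)).run(stream);

  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
  return 0;
}
```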