@@ -189,31 +189,81 @@ cpdef utils.dpnp_descriptor dpnp_identity(n, result_dtype):
189189 return result
190190
191191
192- # TODO this function should work through dpnp_arange_c
193- cpdef tuple dpnp_linspace(start, stop, num, endpoint, retstep, dtype, axis):
194- cdef shape_type_c obj_shape = utils._object_to_tuple(num)
195- cdef utils.dpnp_descriptor result = utils_py.create_output_descriptor_py(obj_shape, dtype, None )
192+ def dpnp_linspace (start , stop , num , dtype = None , device = None , usm_type = None , sycl_queue = None , endpoint = True , retstep = False , axis = 0 ):
193+ usm_type_alloc, sycl_queue_alloc = utils_py.get_usm_allocations([start, stop])
196194
197- if endpoint:
198- steps_count = num - 1
199- else :
200- steps_count = num
195+ # Get sycl_queue.
196+ if sycl_queue is None and device is None :
197+ sycl_queue = sycl_queue_alloc
198+ sycl_queue_normalized = dpnp.get_normalized_queue_device( sycl_queue = sycl_queue, device = device)
201199
202- # if there are steps, then fill values
203- if steps_count > 0 :
204- step = (dpnp.float64(stop) - start) / steps_count
205- for i in range (1 , result.size):
206- result.get_pyobj()[i] = start + step * i
200+ # Get temporary usm_type for getting dtype.
201+ if usm_type is None :
202+ _usm_type = " device" if usm_type_alloc is None else usm_type_alloc
207203 else :
208- step = dpnp.nan
204+ _usm_type = usm_type
205+
206+ # Get dtype.
207+ if not hasattr (start, " dtype" ) and not dpnp.isscalar(start):
208+ start = dpnp.asarray(start, usm_type = _usm_type, sycl_queue = sycl_queue_normalized)
209+ if not hasattr (stop, " dtype" ) and not dpnp.isscalar(stop):
210+ stop = dpnp.asarray(stop, usm_type = _usm_type, sycl_queue = sycl_queue_normalized)
211+ dt = numpy.result_type(start, stop, float (num))
212+ dt = utils_py.map_dtype_to_device(dt, sycl_queue_normalized.sycl_device)
213+ if dtype is None :
214+ dtype = dt
215+
216+ if dpnp.isscalar(start) and dpnp.isscalar(stop):
217+ # Call linspace() function for scalars.
218+ res = dpnp_container.linspace(start,
219+ stop,
220+ num,
221+ dtype = dt,
222+ usm_type = _usm_type,
223+ sycl_queue = sycl_queue_normalized,
224+ endpoint = endpoint)
225+ else :
226+ num = operator.index(num)
227+ if num < 0 :
228+ raise ValueError (" Number of points must be non-negative" )
229+
230+ # Get final usm_type and copy arrays if needed with current dtype, usm_type and sycl_queue.
231+ # Do not need to copy usm_ndarray by usm_type if it is not explicitly stated.
232+ if usm_type is None :
233+ usm_type = _usm_type
234+ if not hasattr (start, " usm_type" ):
235+ _start = dpnp.asarray(start, dtype = dt, usm_type = usm_type, sycl_queue = sycl_queue_normalized)
236+ else :
237+ _start = dpnp.asarray(start, dtype = dt, sycl_queue = sycl_queue_normalized)
238+ if not hasattr (stop, " usm_type" ):
239+ _stop = dpnp.asarray(stop, dtype = dt, usm_type = usm_type, sycl_queue = sycl_queue_normalized)
240+ else :
241+ _stop = dpnp.asarray(stop, dtype = dt, sycl_queue = sycl_queue_normalized)
242+ else :
243+ _start = dpnp.asarray(start, dtype = dt, usm_type = usm_type, sycl_queue = sycl_queue_normalized)
244+ _stop = dpnp.asarray(stop, dtype = dt, usm_type = usm_type, sycl_queue = sycl_queue_normalized)
209245
210- # if result is not empty, then fiil first and last elements
211- if num > 0 :
212- result.get_pyobj()[0 ] = start
213- if endpoint and result.size > 1 :
214- result.get_pyobj()[result.size - 1 ] = stop
246+ # FIXME: issue #1304. Mathematical operations with scalar don't follow data type.
247+ _num = dpnp.asarray((num - 1 ) if endpoint else num, dtype = dt, usm_type = usm_type, sycl_queue = sycl_queue_normalized)
248+
249+ step = (_stop - _start) / _num
250+
251+ res = dpnp_container.arange(0 ,
252+ stop = num,
253+ step = 1 ,
254+ dtype = dt,
255+ usm_type = usm_type,
256+ sycl_queue = sycl_queue_normalized)
257+
258+ res = res.reshape((- 1 ,) + (1 ,) * step.ndim)
259+ res = res * step + _start
260+
261+ if endpoint and num > 1 :
262+ res[- 1 ] = dpnp_container.full(step.shape, _stop)
215263
216- return (result.get_pyobj(), step)
264+ if numpy.issubdtype(dtype, dpnp.integer):
265+ dpnp.floor(res, out = res)
266+ return res.astype(dtype)
217267
218268
219269cpdef utils.dpnp_descriptor dpnp_logspace(start, stop, num, endpoint, base, dtype, axis):
0 commit comments