Update federated secure branch. #10622

Merged · 52 commits · Jul 22, 2024

Commits
d33043a
[coll] Allow using local host for testing. (#10526)
trivialfis Jul 2, 2024
e537b09
Fix boolean array for arrow-backed DF. (#10527)
trivialfis Jul 2, 2024
9cb4c93
[EM] Move prefetch in reset into the end of the iteration. (#10529)
trivialfis Jul 2, 2024
628411a
Enhance the threadpool implementation. (#10531)
trivialfis Jul 3, 2024
6243e7c
[doc] Update link to release notes. [skip ci] (#10533)
trivialfis Jul 3, 2024
cd1d108
[doc] Fix learning to rank tutorial. [skip ci] (#10539)
trivialfis Jul 3, 2024
620b2b1
Cache GPU histogram kernel configuration. (#10538)
trivialfis Jul 4, 2024
513d7a7
[sycl] Reorder if-else statements to allow using of cpu branches for …
razdoburdin Jul 5, 2024
00264eb
[EM] Basic distributed test for external memory. (#10492)
trivialfis Jul 5, 2024
0a3941b
[sycl] Improve build configuration. (#10548)
razdoburdin Jul 6, 2024
2266db1
[R] Update roxygen. (#10556)
trivialfis Jul 8, 2024
8d0f2bf
[doc] Add more detailed explanations for advanced objectives (#10283)
david-cortes Jul 8, 2024
3ec74a1
[doc] Add `build_info` to autodoc. [skip ci] (#10551)
trivialfis Jul 9, 2024
8e2b874
[doc] Add notes about RMM and device ordinal. [skip ci] (#10562)
trivialfis Jul 10, 2024
baba3e9
Fix empty partition. (#10559)
trivialfis Jul 10, 2024
34b154c
Avoid the use of size_t in the partitioner. (#10541)
trivialfis Jul 10, 2024
5f910cd
[EM] Handle base idx in GPU histogram. (#10549)
trivialfis Jul 10, 2024
89da9f9
[fed] Split up federated test CMake file. (#10566)
trivialfis Jul 11, 2024
1ca4bfd
Avoid thrust vector initialization. (#10544)
trivialfis Jul 11, 2024
6c40318
Fix column split race condition. (#10572)
trivialfis Jul 11, 2024
5fea9d2
Small cleanup for CMake scripts. (#10573)
trivialfis Jul 11, 2024
ce97de2
replace channel for sycl dependencies (#10576)
razdoburdin Jul 12, 2024
6fc1088
Bump org.apache.maven.plugins:maven-project-info-reports-plugin (#10497)
dependabot[bot] Jul 14, 2024
5b7c689
Bump org.apache.flink:flink-clients in /jvm-packages (#10517)
dependabot[bot] Jul 14, 2024
7996914
Bump org.apache.maven.plugins:maven-surefire-plugin (#10429)
dependabot[bot] Jul 14, 2024
8b77964
Bump commons-logging:commons-logging in /jvm-packages/xgboost4j-spark…
dependabot[bot] Jul 15, 2024
0f789e2
Bump org.apache.maven.plugins:maven-jar-plugin (#10458)
dependabot[bot] Jul 15, 2024
5b68b68
Bump org.apache.maven.plugins:maven-project-info-reports-plugin (#10585)
dependabot[bot] Jul 15, 2024
a81ccab
Bump org.apache.maven.plugins:maven-release-plugin (#10586)
dependabot[bot] Jul 15, 2024
b7511cb
Bump net.alchim31.maven:scala-maven-plugin in /jvm-packages/xgboost4j…
dependabot[bot] Jul 15, 2024
17c6430
Bump org.apache.maven.plugins:maven-checkstyle-plugin in /jvm-package…
dependabot[bot] Jul 15, 2024
ab982e7
[R] Redesigned `xgboost()` interface skeleton (#10456)
david-cortes Jul 15, 2024
bbd3085
[jvm-packages] Bump rapids version. (#10588)
trivialfis Jul 15, 2024
fa8fea1
Bump scalatest.version from 3.2.18 to 3.2.19 in /jvm-packages/xgboost…
dependabot[bot] Jul 15, 2024
370dce9
[Doc] Fix CRAN badge in README [skip ci] (#10587)
RektPunk Jul 15, 2024
5a92ffe
Partial fix for CTK 12.5 (#10574)
trivialfis Jul 16, 2024
a6a8a55
Merge approx tests. (#10583)
trivialfis Jul 16, 2024
ee8bb60
[CI] Reduce the frequency of dependabot PRs (#10593)
hcho3 Jul 17, 2024
c41a657
Bump actions/setup-python from 5.1.0 to 5.1.1 (#10599)
dependabot[bot] Jul 17, 2024
919cfd9
Bump actions/upload-artifact from 4.3.3 to 4.3.4 (#10600)
dependabot[bot] Jul 17, 2024
07732e0
Bump com.fasterxml.jackson.core:jackson-databind (#10590)
dependabot[bot] Jul 17, 2024
e9fbce9
Refactor `DeviceUVector`. (#10595)
trivialfis Jul 17, 2024
292bb67
[EM] Support mmap backed ellpack. (#10602)
trivialfis Jul 18, 2024
7ab93f3
[CI] Fix test environment. (#10609)
trivialfis Jul 18, 2024
326921d
[CI] Build a CPU-only wheel under name `xgboost-cpu` (#10603)
hcho3 Jul 19, 2024
344ddeb
Drop support for CUDA legacy stream. (#10607)
trivialfis Jul 19, 2024
0846ad8
Optionally skip cupy on windows. (#10611)
trivialfis Jul 20, 2024
cb62f9e
[EM] Prevent init with CUDA malloc resource. (#10606)
trivialfis Jul 20, 2024
6d9fcb7
Move device histogram storage into `histogram.cuh`. (#10608)
trivialfis Jul 21, 2024
b039cd3
Merge branch 'master' into update-fed
trivialfis Jul 22, 2024
ab022ee
Fix.
trivialfis Jul 22, 2024
1f658b2
Fix.
trivialfis Jul 22, 2024
4 changes: 2 additions & 2 deletions .github/dependabot.yml
@@ -12,7 +12,7 @@ updates:
   - package-ecosystem: "maven"
     directory: "/jvm-packages/xgboost4j"
     schedule:
-      interval: "daily"
+      interval: "monthly"
   - package-ecosystem: "maven"
     directory: "/jvm-packages/xgboost4j-gpu"
     schedule:
@@ -24,7 +24,7 @@ updates:
   - package-ecosystem: "maven"
     directory: "/jvm-packages/xgboost4j-spark"
     schedule:
-      interval: "daily"
+      interval: "monthly"
   - package-ecosystem: "maven"
     directory: "/jvm-packages/xgboost4j-spark-gpu"
     schedule:
2 changes: 1 addition & 1 deletion .github/workflows/i386.yml
@@ -23,7 +23,7 @@ jobs:
       with:
         submodules: 'true'
     - name: Set up Docker Buildx
-      uses: docker/setup-buildx-action@v3
+      uses: docker/setup-buildx-action@v3.4.0
       with:
         driver-opts: network=host
     - name: Build and push container
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
@@ -180,7 +180,7 @@ jobs:
     - uses: actions/checkout@a5ac7e51b41094c92402da3b24376905380afc29 # v4.1.6
       with:
         submodules: 'true'
-    - uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # v5.1.0
+    - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # v5.1.1
      with:
        python-version: "3.8"
        architecture: 'x64'
2 changes: 1 addition & 1 deletion .github/workflows/python_tests.yml
@@ -319,7 +319,7 @@ jobs:
         submodules: 'true'

     - name: Set up Python 3.8
-      uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # v5.1.0
+      uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # v5.1.1
       with:
         python-version: 3.8
2 changes: 1 addition & 1 deletion .github/workflows/r_tests.yml
@@ -84,7 +84,7 @@ jobs:
       key: ${{ runner.os }}-r-${{ matrix.config.r }}-7-${{ hashFiles('R-package/DESCRIPTION') }}
       restore-keys: ${{ runner.os }}-r-${{ matrix.config.r }}-7-${{ hashFiles('R-package/DESCRIPTION') }}

-    - uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # v5.1.0
+    - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # v5.1.1
       with:
         python-version: "3.8"
         architecture: 'x64'
2 changes: 1 addition & 1 deletion .github/workflows/scorecards.yml
@@ -41,7 +41,7 @@ jobs:
     # Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF
     # format to the repository Actions tab.
     - name: "Upload artifact"
-      uses: actions/upload-artifact@65462800fd760344b1a7b4382951275a0abb4808 # v4.3.3
+      uses: actions/upload-artifact@0b2256b8c012f0828dc542b3febcab082c67f72b # v4.3.4
       with:
         name: SARIF file
         path: results.sarif
1 change: 0 additions & 1 deletion CMakeLists.txt
@@ -71,7 +71,6 @@ option(HIDE_CXX_SYMBOLS "Build shared library and hide all C++ symbols" OFF)
 option(KEEP_BUILD_ARTIFACTS_IN_BINARY_DIR "Output build artifacts in CMake binary dir" OFF)
 ## CUDA
 option(USE_CUDA "Build with GPU acceleration" OFF)
-option(USE_PER_THREAD_DEFAULT_STREAM "Build with per-thread default stream" ON)
 option(USE_NCCL "Build with NCCL to enable distributed GPU support." OFF)
 # This is specifically designed for PyPI binary release and should be disabled for most of the cases.
 option(USE_DLOPEN_NCCL "Whether to load nccl dynamically." OFF)
1 change: 1 addition & 0 deletions R-package/src/Makevars.in
@@ -113,6 +113,7 @@ OBJECTS= \
 	$(PKGROOT)/src/common/charconv.o \
 	$(PKGROOT)/src/common/column_matrix.o \
 	$(PKGROOT)/src/common/common.o \
+	$(PKGROOT)/src/common/cuda_rt_utils.o \
 	$(PKGROOT)/src/common/error_msg.o \
 	$(PKGROOT)/src/common/hist_util.o \
 	$(PKGROOT)/src/common/host_device_vector.o \
1 change: 1 addition & 0 deletions R-package/src/Makevars.win
@@ -113,6 +113,7 @@ OBJECTS= \
 	$(PKGROOT)/src/common/charconv.o \
 	$(PKGROOT)/src/common/column_matrix.o \
 	$(PKGROOT)/src/common/common.o \
+	$(PKGROOT)/src/common/cuda_rt_utils.o \
 	$(PKGROOT)/src/common/error_msg.o \
 	$(PKGROOT)/src/common/hist_util.o \
 	$(PKGROOT)/src/common/host_device_vector.o \
6 changes: 3 additions & 3 deletions README.md
@@ -4,8 +4,8 @@
 [![Build Status](https://badge.buildkite.com/aca47f40a32735c00a8550540c5eeff6a4c1d246a580cae9b0.svg?branch=master)](https://buildkite.com/xgboost/xgboost-ci)
 [![XGBoost-CI](https://github.com/dmlc/xgboost/workflows/XGBoost-CI/badge.svg?branch=master)](https://github.com/dmlc/xgboost/actions)
 [![Documentation Status](https://readthedocs.org/projects/xgboost/badge/?version=latest)](https://xgboost.readthedocs.org)
-[![GitHub license](http://dmlc.github.io/img/apache2.svg)](./LICENSE)
-[![CRAN Status Badge](http://www.r-pkg.org/badges/version/xgboost)](http://cran.r-project.org/web/packages/xgboost)
+[![GitHub license](https://dmlc.github.io/img/apache2.svg)](./LICENSE)
+[![CRAN Status Badge](https://www.r-pkg.org/badges/version/xgboost)](https://cran.r-project.org/web/packages/xgboost)
 [![PyPI version](https://badge.fury.io/py/xgboost.svg)](https://pypi.python.org/pypi/xgboost/)
 [![Conda version](https://img.shields.io/conda/vn/conda-forge/py-xgboost.svg)](https://anaconda.org/conda-forge/py-xgboost)
 [![Optuna](https://img.shields.io/badge/Optuna-integrated-blue)](https://optuna.org)
@@ -35,7 +35,7 @@ Checkout the [Community Page](https://xgboost.ai/community).

 Reference
 ---------
-- Tianqi Chen and Carlos Guestrin. [XGBoost: A Scalable Tree Boosting System](http://arxiv.org/abs/1603.02754). In 22nd SIGKDD Conference on Knowledge Discovery and Data Mining, 2016
+- Tianqi Chen and Carlos Guestrin. [XGBoost: A Scalable Tree Boosting System](https://arxiv.org/abs/1603.02754). In 22nd SIGKDD Conference on Knowledge Discovery and Data Mining, 2016
 - XGBoost originates from research project at University of Washington.

 Sponsors
8 changes: 2 additions & 6 deletions cmake/Utils.cmake
@@ -80,12 +80,8 @@ function(xgboost_set_cuda_flags target)
     $<$<COMPILE_LANGUAGE:CUDA>:--expt-extended-lambda>
     $<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>
     $<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=${OpenMP_CXX_FLAGS}>
-    $<$<COMPILE_LANGUAGE:CUDA>:-Xfatbin=-compress-all>)
-
-  if(USE_PER_THREAD_DEFAULT_STREAM)
-    target_compile_options(${target} PRIVATE
-      $<$<COMPILE_LANGUAGE:CUDA>:--default-stream per-thread>)
-  endif()
+    $<$<COMPILE_LANGUAGE:CUDA>:-Xfatbin=-compress-all>
+    $<$<COMPILE_LANGUAGE:CUDA>:--default-stream per-thread>)

   if(FORCE_COLORED_OUTPUT)
     if(FORCE_COLORED_OUTPUT AND (CMAKE_GENERATOR STREQUAL "Ninja") AND
14 changes: 14 additions & 0 deletions dev/release-artifacts.py
@@ -2,6 +2,7 @@

 tqdm, sh are required to run this script.
 """
+
 import argparse
 import os
 import shutil
@@ -106,6 +107,15 @@ def make_pysrc_wheel(
     if not os.path.exists(dist):
         os.mkdir(dist)

+    # Apply patch to remove NCCL dependency
+    # Save the original content of pyproject.toml so that we can restore it later
+    with DirectoryExcursion(ROOT):
+        with open("python-package/pyproject.toml", "r") as f:
+            orig_pyproj_lines = f.read()
+        with open("tests/buildkite/remove_nccl_dep.patch", "r") as f:
+            patch_lines = f.read()
+        subprocess.run(["patch", "-p0"], input=patch_lines, text=True)
+
     with DirectoryExcursion(os.path.join(ROOT, "python-package")):
         subprocess.check_call(["python", "-m", "build", "--sdist"])
         if rc is not None:
@@ -117,6 +127,10 @@ def make_pysrc_wheel(
         target = os.path.join(dist, name)
         shutil.move(src, target)

+    with DirectoryExcursion(ROOT):
+        with open("python-package/pyproject.toml", "w") as f:
+            print(orig_pyproj_lines, file=f, end="")
+

 def download_py_packages(
     branch: str, major: int, minor: int, commit_hash: str, outdir: str
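The release-script hunks above bracket a `patch` invocation with a manual save of `pyproject.toml` before the sdist build and a write-back after it. The same idea can be factored into a small context manager — a sketch only; `restore_file` is a hypothetical helper, not part of the actual release script:

```python
import contextlib
from pathlib import Path


@contextlib.contextmanager
def restore_file(path):
    """Snapshot a text file and write the original content back on exit.

    Mirrors the save/patch/restore sequence in make_pysrc_wheel: read
    pyproject.toml, apply a patch for the sdist build, then restore.
    """
    p = Path(path)
    original = p.read_text()
    try:
        yield original
    finally:
        # Runs even if the patched build raises, so the tree stays clean.
        p.write_text(original)
```

With such a helper the patch step would reduce to `with restore_file("python-package/pyproject.toml"): subprocess.run(["patch", "-p0"], input=patch_lines, text=True)`, and the restore no longer depends on the build succeeding.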
13 changes: 13 additions & 0 deletions doc/install.rst
@@ -76,6 +76,19 @@ Capabilities of binary wheels for each platform:
 | Windows             | |tick|  | |cross|              |
 +---------------------+---------+----------------------+

+Minimal installation (CPU-only)
+*******************************
+The default installation with ``pip`` will install the full XGBoost package, including support for the GPU algorithms and federated learning.
+
+You may choose to reduce the size of the installed package and save disk space by opting to install ``xgboost-cpu`` instead:
+
+.. code-block:: bash
+
+  pip install xgboost-cpu
+
+The ``xgboost-cpu`` variant has a drastically smaller disk footprint, but does not provide some features, such as the GPU algorithms and
+federated learning.
+
 Conda
 *****
2 changes: 2 additions & 0 deletions include/xgboost/base.h
@@ -7,6 +7,8 @@
 #define XGBOOST_BASE_H_

 #include <dmlc/omp.h>  // for omp_uint, omp_ulong
+// Put the windefs here to guard as many files as possible.
+#include <xgboost/windefs.h>

 #include <cstdint>  // for int32_t, uint64_t, int16_t
 #include <ostream>  // for ostream
7 changes: 4 additions & 3 deletions include/xgboost/collective/poll_utils.h
@@ -4,13 +4,14 @@
  * \author Tianqi Chen
  */
 #pragma once
-#include "xgboost/collective/result.h"
-#include "xgboost/collective/socket.h"
+#include <xgboost/collective/result.h>
+#include <xgboost/collective/socket.h>

 #if defined(_WIN32)
+#include <xgboost/windefs.h>
 // Socket API
 #include <winsock2.h>
 #include <ws2tcpip.h>

 #else

 #include <arpa/inet.h>
28 changes: 7 additions & 21 deletions include/xgboost/collective/socket.h
@@ -1,12 +1,8 @@
 /**
- * Copyright (c) 2022-2024, XGBoost Contributors
+ * Copyright 2022-2024, XGBoost Contributors
  */
 #pragma once

-#if !defined(NOMINMAX) && defined(_WIN32)
-#define NOMINMAX
-#endif  // !defined(NOMINMAX)
-
 #include <cerrno>   // errno, EINTR, EBADF
 #include <climits>  // HOST_NAME_MAX
 #include <cstddef>  // std::size_t
@@ -18,18 +14,12 @@

 #if defined(__linux__)
 #include <sys/ioctl.h>  // for TIOCOUTQ, FIONREAD
 #endif  // defined(__linux__)

-#if !defined(xgboost_IS_MINGW)
-
-#if defined(__MINGW32__)
-#define xgboost_IS_MINGW 1
-#endif  // defined(__MINGW32__)
-
-#endif  // xgboost_IS_MINGW
-
 #if defined(_WIN32)

+// Guard the include.
+#include <xgboost/windefs.h>
+// Socket API
 #include <winsock2.h>
 #include <ws2tcpip.h>
@@ -41,9 +31,9 @@ using in_port_t = std::uint16_t;

 #if !defined(xgboost_IS_MINGW)
 using ssize_t = int;
-#endif // !xgboost_IS_MINGW()
+#endif  // !xgboost_IS_MINGW()

-#else // UNIX
+#else  // UNIX

 #include <arpa/inet.h>  // inet_ntop
 #include <fcntl.h>      // fcntl, F_GETFL, O_NONBLOCK
@@ -839,7 +829,3 @@ Result INetNToP(H const &host, std::string *p_out) {
 }  // namespace xgboost

 #undef xgboost_CHECK_SYS_CALL
-
-#if defined(xgboost_IS_MINGW)
-#undef xgboost_IS_MINGW
-#endif
33 changes: 33 additions & 0 deletions include/xgboost/windefs.h
@@ -0,0 +1,33 @@
/**
* Copyright 2024, XGBoost Contributors
*
* @brief Macro for Windows.
*/
#pragma once

#if !defined(xgboost_IS_WIN)

#if defined(_MSC_VER) || defined(__MINGW32__)
#define xgboost_IS_WIN 1
#endif // defined(_MSC_VER) || defined(__MINGW32__)

#endif // !defined(xgboost_IS_WIN)

#if defined(xgboost_IS_WIN)

#if !defined(NOMINMAX)
#define NOMINMAX
#endif // !defined(NOMINMAX)

// A macro used inside `windows.h` to avoid conflicts with `winsock2.h`
#define WIN32_LEAN_AND_MEAN

#if !defined(xgboost_IS_MINGW)

#if defined(__MINGW32__)
#define xgboost_IS_MINGW 1
#endif // defined(__MINGW32__)

#endif // xgboost_IS_MINGW

#endif // defined(xgboost_IS_WIN)
8 changes: 4 additions & 4 deletions jvm-packages/pom.xml
@@ -37,17 +37,17 @@
     <junit.version>4.13.2</junit.version>
     <spark.version>3.5.1</spark.version>
     <spark.version.gpu>3.5.1</spark.version.gpu>
-    <fasterxml.jackson.version>2.15.2</fasterxml.jackson.version>
+    <fasterxml.jackson.version>2.17.2</fasterxml.jackson.version>
     <scala.version>2.12.18</scala.version>
     <scala.binary.version>2.12</scala.binary.version>
     <hadoop.version>3.4.0</hadoop.version>
     <maven.wagon.http.retryHandler.count>5</maven.wagon.http.retryHandler.count>
     <log.capi.invocation>OFF</log.capi.invocation>
     <use.cuda>OFF</use.cuda>
-    <cudf.version>24.04.0</cudf.version>
-    <spark.rapids.version>24.04.1</spark.rapids.version>
+    <cudf.version>24.06.0</cudf.version>
+    <spark.rapids.version>24.06.0</spark.rapids.version>
     <cudf.classifier>cuda12</cudf.classifier>
-    <scalatest.version>3.2.18</scalatest.version>
+    <scalatest.version>3.2.19</scalatest.version>
     <scala-collection-compat.version>2.12.0</scala-collection-compat.version>
     <skip.native.build>false</skip.native.build>
6 changes: 3 additions & 3 deletions python-package/pyproject.toml
@@ -7,13 +7,13 @@ build-backend = "packager.pep517"

 [project]
 name = "xgboost"
-version = "2.2.0-dev"
-description = "XGBoost Python Package"
-readme = { file = "README.rst", content-type = "text/x-rst" }
 authors = [
     { name = "Hyunsu Cho", email = "chohyu01@cs.washington.edu" },
     { name = "Jiaming Yuan", email = "jm.yuan@outlook.com" }
 ]
+description = "XGBoost Python Package"
+readme = { file = "README.rst", content-type = "text/x-rst" }
+version = "2.2.0-dev"
 requires-python = ">=3.8"
 license = { text = "Apache-2.0" }
 classifiers = [
41 changes: 11 additions & 30 deletions python-package/xgboost/testing/__init__.py
@@ -45,6 +45,7 @@
     get_cancer,
     get_digits,
     get_sparse,
+    make_batches,
     memory,
 )
@@ -161,7 +162,16 @@ def no_cudf() -> PytestSkip:


 def no_cupy() -> PytestSkip:
-    return no_mod("cupy")
+    skip_cupy = no_mod("cupy")
+    if not skip_cupy["condition"] and system() == "Windows":
+        import cupy as cp
+
+        # Cupy might run into issue on Windows due to missing compiler
+        try:
+            cp.array([1, 2, 3]).sum()
+        except Exception:  # pylint: disable=broad-except
+            skip_cupy["condition"] = True
+    return skip_cupy


 def no_dask_cudf() -> PytestSkip:
@@ -248,35 +258,6 @@ def as_arrays(
     return X, y, w


-def make_batches(  # pylint: disable=too-many-arguments,too-many-locals
-    n_samples_per_batch: int,
-    n_features: int,
-    n_batches: int,
-    use_cupy: bool = False,
-    *,
-    vary_size: bool = False,
-    random_state: int = 1994,
-) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
-    X = []
-    y = []
-    w = []
-    if use_cupy:
-        import cupy
-
-        rng = cupy.random.RandomState(random_state)
-    else:
-        rng = np.random.RandomState(random_state)
-    for i in range(n_batches):
-        n_samples = n_samples_per_batch + i * 10 if vary_size else n_samples_per_batch
-        _X = rng.randn(n_samples, n_features)
-        _y = rng.randn(n_samples)
-        _w = rng.uniform(low=0, high=1, size=n_samples)
-        X.append(_X)
-        y.append(_y)
-        w.append(_w)
-    return X, y, w
-
-
 def make_regression(
     n_samples: int, n_features: int, use_cupy: bool
 ) -> Tuple[ArrayLike, ArrayLike, ArrayLike]:
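The reworked `no_cupy` above extends a plain import check with a first-use smoke test, since a module like cupy can import successfully on Windows yet fail the first time it runs a kernel. The general pattern can be sketched as follows — `optional_mod` is a hypothetical helper for illustration, not the actual `no_mod` from `xgboost.testing`:

```python
import importlib
from typing import Any, Callable, Optional


def optional_mod(name: str, smoke: Optional[Callable[[Any], Any]] = None) -> dict:
    """Return a pytest-style skip dict for an optional dependency.

    The module must both import and, when a `smoke` callable is given,
    survive a trivial first use; either failure marks it unavailable.
    """
    try:
        mod = importlib.import_module(name)
    except ImportError:
        return {"condition": True, "reason": f"{name} is not installed"}
    if smoke is not None:
        try:
            smoke(mod)
        except Exception:  # first-use failures also count as unavailable
            return {"condition": True, "reason": f"{name} failed a smoke test"}
    return {"condition": False, "reason": f"{name} is available"}
```

For example, `optional_mod("cupy", smoke=lambda cp: cp.array([1, 2, 3]).sum())` reproduces the probe in the hunk above, and the returned dict can feed `pytest.mark.skipif` directly.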