Skip to content

Commit 5cab19e

Browse files
authored
Merge branch 'main' into branch-3
2 parents f5efa23 + 0a41c60 commit 5cab19e

File tree

28 files changed

+1456
-234
lines changed

28 files changed

+1456
-234
lines changed

.buildkite/pipeline.yml

Lines changed: 66 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
steps:
22
- group: ":test_tube: Tests"
33
steps:
4-
- label: "CUDA Julia v{{matrix.version}} -- {{matrix.group}}"
4+
- label: ":julia: :linux: CUDA Julia v{{matrix.version}} -- {{matrix.group}}"
55
matrix:
66
setup:
77
version:
@@ -33,7 +33,7 @@ steps:
3333
env:
3434
REACTANT_TEST_GROUP: "{{matrix.group}}"
3535
if: build.message !~ /\[skip tests\]/
36-
timeout_in_minutes: 60
36+
timeout_in_minutes: 120
3737

3838
- label: ":julia: :linux: aarch64 - Julia v{{matrix.version}} -- {{matrix.group}}"
3939
matrix:
@@ -70,78 +70,78 @@ steps:
7070
env:
7171
REACTANT_TEST_GROUP: "{{matrix.group}}"
7272
if: build.message !~ /\[skip tests\]/
73-
timeout_in_minutes: 60
73+
timeout_in_minutes: 120
7474

75-
- group: ":racehorse: Benchmarks"
76-
steps:
77-
- label: "CPU: Run Benchmarks"
78-
plugins:
79-
- JuliaCI/julia#v1:
80-
version: "1"
81-
command: |
82-
julia --project=benchmark -e 'println("--- :julia: Instantiating project")
83-
using Pkg
84-
Pkg.develop([PackageSpec(path=pwd()), PackageSpec(path="lib/ReactantCore")])'
75+
# - group: ":racehorse: Benchmarks"
76+
# steps:
77+
# - label: "CPU: Run Benchmarks"
78+
# plugins:
79+
# - JuliaCI/julia#v1:
80+
# version: "1"
81+
# command: |
82+
# julia --project=benchmark -e 'println("--- :julia: Instantiating project")
83+
# using Pkg
84+
# Pkg.develop([PackageSpec(path=pwd()), PackageSpec(path="lib/ReactantCore")])'
8585

86-
julia --project=benchmark -e 'println("--- :julia: Run Benchmarks")
87-
include("benchmark/runbenchmarks.jl")'
88-
artifact_paths:
89-
- "benchmark/results/*"
90-
agents:
91-
# Models are quite large so we need a decent sized machine. Don't tell Chris we
92-
# are stealing SciMLBenchmarks machine :P
93-
queue: "juliaecosystem"
94-
sandbox_capable: true
95-
exclusive: true
96-
arch: "x86_64"
97-
env:
98-
BENCHMARK_GROUP: CPU
99-
JULIA_NUM_THREADS: "auto"
100-
timeout_in_minutes: 120
86+
# julia --project=benchmark -e 'println("--- :julia: Run Benchmarks")
87+
# include("benchmark/runbenchmarks.jl")'
88+
# artifact_paths:
89+
# - "benchmark/results/*"
90+
# agents:
91+
# # Models are quite large so we need a decent sized machine. Don't tell Chris we
92+
# # are stealing SciMLBenchmarks machine :P
93+
# queue: "juliaecosystem"
94+
# sandbox_capable: true
95+
# exclusive: true
96+
# arch: "x86_64"
97+
# env:
98+
# BENCHMARK_GROUP: CPU
99+
# JULIA_NUM_THREADS: "auto"
100+
# timeout_in_minutes: 120
101101

102-
- label: "CUDA: Run Benchmarks"
103-
plugins:
104-
- JuliaCI/julia#v1:
105-
version: "1"
106-
command: |
107-
julia --project=benchmark -e 'println("--- :julia: Instantiating project")
108-
using Pkg
109-
Pkg.develop([PackageSpec(path=pwd()), PackageSpec(path="lib/ReactantCore")])'
102+
# - label: "CUDA: Run Benchmarks"
103+
# plugins:
104+
# - JuliaCI/julia#v1:
105+
# version: "1"
106+
# command: |
107+
# julia --project=benchmark -e 'println("--- :julia: Instantiating project")
108+
# using Pkg
109+
# Pkg.develop([PackageSpec(path=pwd()), PackageSpec(path="lib/ReactantCore")])'
110110

111-
julia --project=benchmark -e 'println("--- :julia: Run Benchmarks")
112-
include("benchmark/runbenchmarks.jl")'
113-
artifact_paths:
114-
- "benchmark/results/*"
115-
agents:
116-
queue: "benchmark"
117-
gpu: "rtx4070"
118-
cuda: "*"
119-
env:
120-
BENCHMARK_GROUP: CUDA
121-
JULIA_NUM_THREADS: "auto"
122-
timeout_in_minutes: 120
111+
# julia --project=benchmark -e 'println("--- :julia: Run Benchmarks")
112+
# include("benchmark/runbenchmarks.jl")'
113+
# artifact_paths:
114+
# - "benchmark/results/*"
115+
# agents:
116+
# queue: "benchmark"
117+
# gpu: "rtx4070"
118+
# cuda: "*"
119+
# env:
120+
# BENCHMARK_GROUP: CUDA
121+
# JULIA_NUM_THREADS: "auto"
122+
# timeout_in_minutes: 120
123123

124-
- wait: ~
125-
continue_on_failure: true
124+
# - wait: ~
125+
# continue_on_failure: true
126126

127-
- label: "Combine benchmarks"
128-
plugins:
129-
- JuliaCI/julia#v1:
130-
version: "1"
131-
command: |
132-
buildkite-agent artifact download "benchmark/results/*" .
127+
# - label: "Combine benchmarks"
128+
# plugins:
129+
# - JuliaCI/julia#v1:
130+
# version: "1"
131+
# command: |
132+
# buildkite-agent artifact download "benchmark/results/*" .
133133

134-
julia -e 'println("--- :julia: Instantiating project")
135-
using Pkg
136-
Pkg.add("BenchmarkTools")
134+
# julia -e 'println("--- :julia: Instantiating project")
135+
# using Pkg
136+
# Pkg.add("BenchmarkTools")
137137

138-
println("--- :julia: Combining Benchmarks")
139-
include("benchmark/aggregate.jl")'
140-
artifact_paths:
141-
- "benchmark/results/combinedbenchmarks.json"
142-
agents:
143-
queue: "juliagpu"
144-
timeout_in_minutes: 10
138+
# println("--- :julia: Combining Benchmarks")
139+
# include("benchmark/aggregate.jl")'
140+
# artifact_paths:
141+
# - "benchmark/results/combinedbenchmarks.json"
142+
# agents:
143+
# queue: "juliagpu"
144+
# timeout_in_minutes: 10
145145

146146
# - label: "AMDGPU Julia v{{matrix.version}}"
147147
# matrix:

.github/workflows/CI.yml

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ jobs:
2323
version:
2424
- '1.10'
2525
- '1.11'
26-
- 'nightly'
26+
# - 'nightly'
2727
os:
2828
- ubuntu-20.04
2929
- macOS-latest
@@ -51,20 +51,26 @@ jobs:
5151
assertions: true
5252
test_group: neural_networks
5353
- os: ubuntu-20.04
54-
arch: x86
55-
libReactant: packaged
56-
version: '1.10'
57-
test_group: core
58-
- os: ubuntu-20.04
59-
arch: x86
60-
libReactant: packaged
61-
version: '1.10'
62-
test_group: neural_networks
63-
- os: ubuntu-20.04
64-
arch: x86
54+
arch: x64
6555
libReactant: packaged
6656
version: '1.10'
57+
assertions: true
6758
test_group: integration
59+
# - os: ubuntu-20.04
60+
# arch: x86
61+
# libReactant: packaged
62+
# version: '1.10'
63+
# test_group: core
64+
# - os: ubuntu-20.04
65+
# arch: x86
66+
# libReactant: packaged
67+
# version: '1.10'
68+
# test_group: neural_networks
69+
# - os: ubuntu-20.04
70+
# arch: x86
71+
# libReactant: packaged
72+
# version: '1.10'
73+
# test_group: integration
6874
exclude:
6975
# these are run on Buildkite
7076
- os: ubuntu-20.04

Project.toml

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>", "Sergio Sánchez Ramírez <sergio.sanchez.ramirez@bsc.es>", "Paul Berg <paul@plutojl.org>", "Avik Pal <avikpal@mit.edu>"]
4-
version = "0.2.10"
4+
version = "0.2.11"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -15,32 +15,38 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1515
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
1616
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1717
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
18+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1819
ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433"
1920
Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
2021
Scratch = "6c6a2e73-6563-6170-7368-637461726353"
2122

2223
[weakdeps]
2324
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
2425
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
26+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2527
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
28+
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
2629
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2730
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
2831

29-
[sources.ReactantCore]
30-
path = "lib/ReactantCore"
32+
[sources]
33+
ReactantCore = {path = "lib/ReactantCore"}
3134

3235
[extensions]
3336
ReactantAbstractFFTsExt = "AbstractFFTs"
3437
ReactantArrayInterfaceExt = "ArrayInterface"
38+
ReactantCUDAExt = "CUDA"
3539
ReactantNNlibExt = "NNlib"
40+
ReactantRandom123Ext = "Random123"
3641
ReactantStatisticsExt = "Statistics"
3742
ReactantYaoBlocksExt = "YaoBlocks"
3843

3944
[compat]
4045
AbstractFFTs = "1.5"
41-
Adapt = "4"
42-
ArrayInterface = "7.10"
43-
CEnum = "0.4, 0.5"
46+
Adapt = "4.1"
47+
ArrayInterface = "7.17.1"
48+
CEnum = "0.5"
49+
CUDA = "5.5"
4450
Downloads = "1.6"
4551
Enzyme = "0.13.22"
4652
EnzymeCore = "0.8.8"
@@ -50,8 +56,10 @@ NNlib = "0.9.26"
5056
OrderedCollections = "1"
5157
PrecompileTools = "1"
5258
Preferences = "1.4"
59+
Random = "1.10"
60+
Random123 = "1.7"
5361
ReactantCore = "0.1.3"
54-
Reactant_jll = "0.0.26"
62+
Reactant_jll = "0.0.27"
5563
Scratch = "1.2"
5664
Statistics = "1.10"
5765
YaoBlocks = "0.13"
@@ -60,4 +68,5 @@ julia = "1.10"
6068
[extras]
6169
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
6270
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
71+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
6372
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"

deps/ReactantExtra/API.cpp

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,10 @@ extern "C" void RegisterDialects(MlirContext cctx) {
470470
context.loadDialect<mlir::stablehlo::StablehloDialect>();
471471
context.loadDialect<mlir::chlo::ChloDialect>();
472472
}
473+
474+
#include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h"
475+
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h"
476+
#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
473477
extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
474478
mlir::DialectRegistry &registry = *unwrap(creg);
475479

@@ -513,6 +517,11 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
513517
mlir::affine::registerAffinePasses();
514518
mlir::registerReconcileUnrealizedCasts();
515519

520+
mlir::registerLLVMDialectImport(registry);
521+
mlir::registerNVVMDialectImport(registry);
522+
523+
mlir::LLVM::registerInlinerInterface(registry);
524+
516525
/*
517526
registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) {
518527
LLVM::LLVMFunctionType::attachInterface<MemRefInsider>(*ctx);
@@ -540,6 +549,81 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
540549
mlir::enzyme::registerEnzymeJaxTransformExtension(registry);
541550
}
542551

552+
553+
/// Returns an unused symbol in `module` for `oldSymbolName` by trying numeric
554+
/// suffix in `lastUsedID`.
555+
static mlir::StringAttr renameSymbol(llvm::StringRef oldSymName,
556+
unsigned &lastUsedID,
557+
mlir::ModuleOp source,
558+
mlir::ModuleOp target) {
559+
using namespace llvm;
560+
using namespace mlir;
561+
SmallString<64> newSymName(oldSymName);
562+
newSymName.push_back('_');
563+
while (true) {
564+
auto possible = newSymName + Twine(++lastUsedID);
565+
if (!SymbolTable::lookupSymbolIn(source, possible.str()) && !SymbolTable::lookupSymbolIn(target, possible.str())) {
566+
return StringAttr::get(target.getContext(), possible);
567+
}
568+
}
569+
}
570+
571+
572+
/// Checks if a symbol with the same name as `op` already exists in `source`.
573+
/// If so, renames `op` and updates all its references in `target`.
574+
static mlir::LogicalResult
575+
updateSymbolAndAllUses(mlir::SymbolOpInterface op, mlir::ModuleOp source, mlir::ModuleOp target,
576+
unsigned &lastUsedID) {
577+
using namespace llvm;
578+
using namespace mlir;
579+
580+
auto opName = op.getName().str();
581+
582+
if (!SymbolTable::lookupSymbolIn(target, opName)) {
583+
return success();
584+
}
585+
586+
StringAttr newSymName =
587+
renameSymbol(opName, lastUsedID, source, target);
588+
589+
if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, source)))
590+
return op.emitError("unable to update all symbol uses for ")
591+
<< opName << " to " << newSymName;
592+
593+
SymbolTable::setSymbolName(op, newSymName);
594+
return success();
595+
}
596+
597+
/// Moves all top-level operations of `newModC` to the end of `prevModC`.
///
/// Before the move, every symbol operation in `newModC` is processed in order:
///   - if its name is already taken in `prevModC`, it is renamed and its uses
///     inside `newModC` are updated (see `updateSymbolAndAllUses`);
///   - its visibility is set to private.
///
/// `entryfn` is compared against the symbol names as they are in `newModC`
/// before any renaming.
///
/// Returns the operation whose original name equals `entryfn`, or a null
/// `MlirOperation` if no such symbol exists. A null `MlirOperation` is also
/// returned when a clashing symbol could not be renamed; in that case an error
/// has been emitted on the offending operation and nothing is moved into
/// `prevModC` (symbols of `newModC` processed up to that point stay renamed /
/// private). Builds with assertions enabled abort at that point instead.
extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC,
                                      const char *entryfn) {
  auto prevMod = cast<ModuleOp>(*unwrap(prevModC));
  auto newMod = cast<ModuleOp>(*unwrap(newModC));

  Operation *entryFn = nullptr;

  // Shared counter for the numeric suffixes used when renaming clashing
  // symbols, so each rename continues from the last ID that was tried.
  unsigned lastUsedID = 0;

  for (auto &op : *newMod.getBody()) {
    auto symbolOp = dyn_cast<SymbolOpInterface>(op);
    if (!symbolOp)
      continue;

    // Match on the name before a potential rename below.
    if (symbolOp.getName() == entryfn)
      entryFn = &op;

    if (failed(
            updateSymbolAndAllUses(symbolOp, newMod, prevMod, lastUsedID))) {
      assert(0 && "failed to update all uses");
      // With assertions compiled out the assert above is a no-op; do not go on
      // to splice operations whose names still clash into `prevMod`.
      return wrap(static_cast<Operation *>(nullptr));
    }
    SymbolTable::setSymbolVisibility(&op, SymbolTable::Visibility::Private);
  }

  // Transfer ownership of every operation from the new module's body to the
  // end of the previous module's body.
  auto &prevOps = prevMod.getBody()->getOperations();
  prevOps.splice(prevOps.end(), newMod.getBody()->getOperations());
  return wrap(entryFn);
}
626+
543627
#pragma region xla::ifrt
544628

545629
#pragma region xla::ifrt::Value

0 commit comments

Comments
 (0)