Skip to content

Commit c790a59

Browse files
committed
Move curve into shared memory.
1 parent d8aa83d commit c790a59

File tree

4 files changed

+43
-18
lines changed

4 files changed

+43
-18
lines changed

include/flamegpu/runtime/AgentFunction.cuh

+13-3
Original file line numberDiff line numberDiff line change
@@ -67,16 +67,26 @@ __global__ void agent_function_wrapper(
6767
unsigned int *scanFlag_messageOutput,
6868
unsigned int *scanFlag_agentOutput) {
6969
// We place these at the start of shared memory, so we can locate it anywhere in device code without a reference
70+
using detail::sm;
7071
if (threadIdx.x == 0) {
71-
using detail::sm;
72+
#if !defined(SEATBELTS) || SEATBELTS
73+
sm()->device_exception = error_buffer;
74+
#endif
7275
#ifndef __CUDACC_RTC__
73-
sm()->curve = d_curve_table;
7476
sm()->env_buffer = d_env_buffer;
7577
#endif
78+
}
79+
#ifndef __CUDACC_RTC__
80+
for (int idx = threadIdx.x; idx < detail::curve::Curve::MAX_VARIABLES; idx += blockDim.x) {
81+
sm()->curve_variables[idx] = d_curve_table->variables[idx];
82+
sm()->curve_hashes[idx] = d_curve_table->hashes[idx];
7683
#if !defined(SEATBELTS) || SEATBELTS
77-
sm()->device_exception = error_buffer;
84+
sm()->curve_type_size[idx] = d_curve_table->type_size[idx];
85+
sm()->curve_elements[idx] = d_curve_table->elements[idx];
86+
sm()->curve_count[idx] = d_curve_table->count[idx];
7887
#endif
7988
}
89+
#endif
8090

8191
#if defined(__CUDACC__) // @todo - This should not be required. This template should only ever be processed by a CUDA compiler.
8292
// Sync the block after Thread 0 has written to shared.

include/flamegpu/runtime/AgentFunctionCondition.cuh

+13-3
Original file line numberDiff line numberDiff line change
@@ -46,16 +46,26 @@ __global__ void agent_function_condition_wrapper(
4646
curandState *d_rng,
4747
unsigned int *scanFlag_conditionResult) {
4848
// We place these at the start of shared memory, so we can locate it anywhere in device code without a reference
49+
using detail::sm;
4950
if (threadIdx.x == 0) {
50-
using detail::sm;
51+
#if !defined(SEATBELTS) || SEATBELTS
52+
sm()->device_exception = error_buffer;
53+
#endif
5154
#ifndef __CUDACC_RTC__
52-
sm()->curve = d_curve_table;
5355
sm()->env_buffer = d_env_buffer;
5456
#endif
57+
}
58+
#ifndef __CUDACC_RTC__
59+
for (int idx = threadIdx.x; idx < detail::curve::Curve::MAX_VARIABLES; idx += blockDim.x) {
60+
sm()->curve_variables[idx] = d_curve_table->variables[idx];
61+
sm()->curve_hashes[idx] = d_curve_table->hashes[idx];
5562
#if !defined(SEATBELTS) || SEATBELTS
56-
sm()->device_exception = error_buffer;
63+
sm()->curve_type_size[idx] = d_curve_table->type_size[idx];
64+
sm()->curve_elements[idx] = d_curve_table->elements[idx];
65+
sm()->curve_count[idx] = d_curve_table->count[idx];
5766
#endif
5867
}
68+
#endif
5969

6070
#if defined(__CUDACC__) // @todo - This should not be required. This template should only ever be processed by a CUDA compiler.
6171
// Sync the block after Thread 0 has written to shared.

include/flamegpu/runtime/detail/SharedBlock.h

+9-4
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,26 @@
11
#ifndef INCLUDE_FLAMEGPU_RUNTIME_DETAIL_SHAREDBLOCK_H_
22
#define INCLUDE_FLAMEGPU_RUNTIME_DETAIL_SHAREDBLOCK_H_
33

4+
#include "flamegpu/runtime/detail/curve/Curve.cuh"
5+
46
namespace flamegpu {
57
namespace exception {
68
struct DeviceExceptionBuffer;
79
} // namespace exception
810
namespace detail {
9-
namespace curve {
10-
struct CurveTable;
11-
} // namespace curve
1211
/**
1312
* This struct represents the data we package into shared memory
1413
* The ifndef __CUDACC_RTC__ will cause the size to be too large for RTC builds, but that's not (currently) an issue
1514
*/
1615
struct SharedBlock {
1716
#ifndef __CUDACC_RTC__
18-
const curve::CurveTable* curve;
17+
curve::Curve::VariableHash curve_hashes[curve::Curve::MAX_VARIABLES];
18+
char* curve_variables[curve::Curve::MAX_VARIABLES];
19+
#if !defined(SEATBELTS) || SEATBELTS
20+
unsigned int curve_type_size[curve::Curve::MAX_VARIABLES];
21+
unsigned int curve_elements[curve::Curve::MAX_VARIABLES];
22+
unsigned int curve_count[curve::Curve::MAX_VARIABLES];
23+
#endif
1924
const char* env_buffer;
2025
#endif
2126
#if !defined(SEATBELTS) || SEATBELTS

include/flamegpu/runtime/detail/curve/DeviceCurve.cuh

+8-8
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ __device__ __forceinline__ DeviceCurve::Variable DeviceCurve::getVariableIndex(c
239239
// (This may inflate register usage based on the max number of iterations in some cases)
240240
for (unsigned int x = 0; x< MAX_VARIABLES; x++) {
241241
const Variable i = (variable_hash + x) & (MAX_VARIABLES - 1);
242-
if (sm()->curve->hashes[i] == variable_hash)
242+
if (sm()->curve_hashes[i] == variable_hash)
243243
return i;
244244
}
245245
return UNKNOWN_VARIABLE;
@@ -252,19 +252,19 @@ __device__ __forceinline__ char* DeviceCurve::getVariablePtr(const char(&variabl
252252
if (cv == UNKNOWN_VARIABLE) {
253253
DTHROW("Curve variable with name '%s' was not found.\n", variableName);
254254
return nullptr;
255-
} else if (sm()->curve->type_size[cv] != sizeof(typename type_decode<T>::type_t)) {
256-
DTHROW("Curve variable with name '%s', type size mismatch %u != %llu.\n", variableName, sm()->curve->type_size[cv], sizeof(typename type_decode<T>::type_t));
255+
} else if (sm()->curve_type_size[cv] != sizeof(typename type_decode<T>::type_t)) {
256+
DTHROW("Curve variable with name '%s', type size mismatch %u != %llu.\n", variableName, sm()->curve_type_size[cv], sizeof(typename type_decode<T>::type_t));
257257
return nullptr;
258-
} else if (!(sm()->curve->elements[cv] == type_decode<T>::len_t * N || (namespace_hash == Curve::variableHash("_environment") && N == 0))) { // Special case, environment can avoid specifying N
259-
DTHROW("Curve variable with name '%s', variable array length mismatch %u != %u.\n", variableName, sm()->curve->elements[cv], type_decode<T>::len_t);
258+
} else if (!(sm()->curve_elements[cv] == type_decode<T>::len_t * N || (namespace_hash == Curve::variableHash("_environment") && N == 0))) { // Special case, environment can avoid specifying N
259+
DTHROW("Curve variable with name '%s', variable array length mismatch %u != %u.\n", variableName, sm()->curve_elements[cv], type_decode<T>::len_t);
260260
return nullptr;
261-
} else if (offset >= sm()->curve->type_size[cv] * sm()->curve->elements[cv] * sm()->curve->count[cv]) { // Note : offset is basically index * sizeof(T)
262-
DTHROW("Curve variable with name '%s', offset exceeds buffer length %u >= %u.\n", offset, sm()->curve->type_size[cv] * sm()->curve->elements[cv] * sm()->curve->count[cv]);
261+
} else if (offset >= sm()->curve_type_size[cv] * sm()->curve_elements[cv] * sm()->curve_count[cv]) { // Note : offset is basically index * sizeof(T)
262+
DTHROW("Curve variable with name '%s', offset exceeds buffer length %u >= %u.\n", offset, sm()->curve_type_size[cv] * sm()->curve_elements[cv] * sm()->curve_count[cv]);
263263
return nullptr;
264264
}
265265
#endif
266266
// return a generic pointer to variable address for given offset
267-
return sm()->curve->variables[cv] + offset;
267+
return sm()->curve_variables[cv] + offset;
268268
}
269269
////
270270
//// Middle Layer CURVE API

0 commit comments

Comments
 (0)