Skip to content

Commit

Permalink
feat: Array indexing + bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
tristanmenzel committed Oct 30, 2024
1 parent 436fc91 commit 818f1aa
Show file tree
Hide file tree
Showing 24 changed files with 3,050 additions and 1,664 deletions.
71 changes: 65 additions & 6 deletions packages/algo-ts/src/arc4/encoded-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ export class Bool {
}
}

abstract class Arc4ReadonlyArray<TItem> extends AbiEncoded {
abstract class Arc4ReadonlyArray<TItem extends AbiEncoded> extends AbiEncoded {
protected items: TItem[]
protected constructor(items: TItem[]) {
super()
Expand All @@ -85,31 +85,77 @@ abstract class Arc4ReadonlyArray<TItem> extends AbiEncoded {
},
})
}

/**
* Returns the current length of this array
*/
get length(): uint64 {
return Uint64(this.items.length)
}

/**
* Returns the item at the given index.
* Negative indexes are taken from the end.
* @param index The index of the item to retrieve
*/
at(index: Uint64Compat): TItem {
return arrayUtil.arrayAt(this.items, index)
}
slice(start: Uint64Compat, end: Uint64Compat): DynamicArray<TItem> {

/** @internal
* Create a new Dynamic array with all items from this array
*/
slice(): DynamicArray<TItem>
/** @internal
* Create a new DynamicArray with all items up till `end`.
* Negative indexes are taken from the end.
* @param end An index in which to stop copying items.
*/
slice(end: Uint64Compat): DynamicArray<TItem>
/** @internal
* Create a new DynamicArray with items from `start`, up until `end`
* Negative indexes are taken from the end.
* @param start An index in which to start copying items.
* @param end An index in which to stop copying items
*/
slice(start: Uint64Compat, end: Uint64Compat): DynamicArray<TItem>
slice(start?: Uint64Compat, end?: Uint64Compat): DynamicArray<TItem> {
return new DynamicArray(...arrayUtil.arraySlice(this.items, start, end))
}

/**
* Returns an iterator for the items in this array
*/
[Symbol.iterator](): IterableIterator<TItem> {
return this.items[Symbol.iterator]()
}

/**
* Returns an iterator for a tuple of the indexes and items in this array
*/
*entries(): IterableIterator<readonly [uint64, TItem]> {
for (const [idx, item] of this.items.entries()) {
yield [Uint64(idx), item]
}
}

/**
* Returns an iterator for the indexes in this array
*/
*keys(): IterableIterator<uint64> {
for (const idx of this.items.keys()) {
yield Uint64(idx)
}
}

/**
* Get or set the item at the specified index.
* Negative indexes are not supported
*/
[index: uint64]: TItem
}

export class StaticArray<TItem, TLength extends number> extends Arc4ReadonlyArray<TItem> {
export class StaticArray<TItem extends AbiEncoded, TLength extends number> extends Arc4ReadonlyArray<TItem> {
constructor()
constructor(...items: TItem[] & { length: TLength })
constructor(...items: TItem[])
Expand All @@ -122,13 +168,26 @@ export class StaticArray<TItem, TLength extends number> extends Arc4ReadonlyArra
}
}

export class DynamicArray<TItem> extends Arc4ReadonlyArray<TItem> {
export class DynamicArray<TItem extends AbiEncoded> extends Arc4ReadonlyArray<TItem> {
constructor(...items: TItem[]) {
super(items)
}
push(...items: TItem[]): void {}

/**
* Push a number of items into this array
* @param items The items to be added to this array
*/
push(...items: TItem[]): void {
this.items.push(...items)
}

/**
* Pop a single item from this array
*/
pop(): TItem {
throw new Error('Not implemented')
const item = this.items.pop()
if (item === undefined) avmError('The array is empty')
return item
}

copy(): DynamicArray<TItem> {
Expand Down
3 changes: 2 additions & 1 deletion packages/algo-ts/tsconfig.build.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
"compilerOptions": {
"outDir": "./dist/",
"noEmit": false,
"declaration": true
"declaration": true,
"stripInternal": true
},
"include": ["src/**/*.ts"],
"exclude": ["src/**/*.spec.ts", "src/**/*.test.ts", "src/**/tests/**"]
Expand Down
1 change: 1 addition & 0 deletions src/awst/json-serialize-awst.ts
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ export class AwstSerializer extends SnakeCaseSerializer<RootNode[]> {
}
return {
...(super.serializerFunction(key, value) as object),
scope: undefined,
file: filePath,
}
}
Expand Down
22 changes: 0 additions & 22 deletions src/awst/models.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
import type { LogicSig } from '../awst_build/models/contract-class'
import { ContractClass } from '../awst_build/models/contract-class'
import type { ContractClassPType } from '../awst_build/ptypes'
import type { Props } from '../typescript-helpers'
import { invariant } from '../util'
import { CustomKeyMap } from '../util/custom-key-map'
import type { SourceLocation } from './source-location'

export enum OnCompletionAction {
Expand Down Expand Up @@ -137,21 +133,3 @@ export enum TransactionKind {
afrz = 5,
appl = 6,
}

export class CompilationSet extends CustomKeyMap<ContractReference | LogicSigReference, ContractClass | LogicSig> {
constructor() {
super((x) => x.toString())
}

get compilationOutputSet() {
return Array.from(this.entries())
.filter(([, meta]) => (meta instanceof ContractClass ? !meta.isAbstract : false))
.map(([ref]) => ref)
}

getContractClass(cref: ContractReference) {
const maybeClass = this.get(cref)
invariant(maybeClass instanceof ContractClass, 'Contract reference must resolve to a contract class')
return maybeClass
}
}
5 changes: 5 additions & 0 deletions src/awst/source-location.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ export class SourceLocation {
this.column = props.column
this.endColumn = props.endColumn
this.scope = props.scope

// Exclude scope from enumerable properties so it doesn't end up being serialized
Object.defineProperty(this, 'scope', {
enumerable: false,
})
}

private static getStartAndEnd(node: ts.Node): { start: number; end: number } {
Expand Down
13 changes: 8 additions & 5 deletions src/awst/to-code-visitor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ export class ToCodeVisitor
return [`goto ${statement.target}`]
}
visitIntersectionSliceExpression(expression: nodes.IntersectionSliceExpression): string {
throw new TodoError('Method not implemented.', { sourceLocation: expression.sourceLocation })
const args = [expression.beginIndex, expression.endIndex]
.flatMap((f) => (typeof f === 'bigint' ? f : (f?.accept(this) ?? [])))
.join(', ')
return `${expression.base.accept(this)}.slice(${args})`
}
visitBoxValueExpression(expression: nodes.BoxValueExpression): string {
if (expression.key instanceof nodes.BytesConstant) {
Expand Down Expand Up @@ -104,16 +107,16 @@ export class ToCodeVisitor
return `copy(${expression.value.accept(this)})`
}
visitArrayConcat(expression: nodes.ArrayConcat): string {
throw new TodoError('Method not implemented.', { sourceLocation: expression.sourceLocation })
return `${expression.left.accept(this)}.concat(${expression.right.accept(this)}`
}
visitArrayPop(expression: nodes.ArrayPop): string {
throw new TodoError('Method not implemented.', { sourceLocation: expression.sourceLocation })
return `${expression.base.accept(this)}.pop()`
}
visitArrayExtend(expression: nodes.ArrayExtend): string {
throw new TodoError('Method not implemented.', { sourceLocation: expression.sourceLocation })
return `${expression.base.accept(this)}.push(...${expression.other.accept(this)}`
}
visitARC4Decode(expression: nodes.ARC4Decode): string {
throw new TodoError('Method not implemented.', { sourceLocation: expression.sourceLocation })
return `ARC4_DECODE(${expression.value})`
}
visitIntrinsicCall(expression: nodes.IntrinsicCall): string {
const immediates = expression.immediates.length ? `<${expression.immediates.map((i) => i).join(', ')}>` : ''
Expand Down
10 changes: 5 additions & 5 deletions src/awst_build/context/awst-build-context.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import ts from 'typescript'
import type { awst } from '../../awst'
import type { ContractReference, LogicSigReference } from '../../awst/models'
import { CompilationSet } from '../../awst/models'
import { nodeFactory } from '../../awst/node-factory'
import type { AppStorageDefinition, Constant } from '../../awst/nodes'
import { SourceLocation } from '../../awst/source-location'
Expand All @@ -10,7 +9,8 @@ import { logger } from '../../logger'
import { codeInvariant, invariant } from '../../util'
import type { AppStorageDeclaration } from '../contract-data'
import type { NodeBuilder } from '../eb'
import type { ContractClass, LogicSig } from '../models/contract-class'
import type { Index, LogicSig } from '../models'
import { CompilationSet } from '../models'
import type { ContractClassPType, PType } from '../ptypes'
import { typeRegistry } from '../type-registry'
import { TypeResolver } from '../type-resolver'
Expand Down Expand Up @@ -90,7 +90,7 @@ export interface AwstBuildContext {

getStorageDefinitionsForContract(contractType: ContractClassPType): AppStorageDefinition[]

addToCompilationSet(compilationTarget: ContractReference, contract: ContractClass): void
addToCompilationSet(compilationTarget: ContractReference, contract: Index): void
addToCompilationSet(compilationTarget: LogicSigReference, logicSig: LogicSig): void

get compilationSet(): CompilationSet
Expand Down Expand Up @@ -243,9 +243,9 @@ class AwstBuildContextImpl implements AwstBuildContext {
return Array.from(result.values())
}

addToCompilationSet(compilationTarget: ContractReference, contract: ContractClass): void
addToCompilationSet(compilationTarget: ContractReference, contract: Index): void
addToCompilationSet(compilationTarget: LogicSigReference, logicSig: LogicSig): void
addToCompilationSet(compilationTarget: ContractReference | LogicSigReference, contractOrSig: ContractClass | LogicSig) {
addToCompilationSet(compilationTarget: ContractReference | LogicSigReference, contractOrSig: Index | LogicSig) {
if (this.#compilationSet.has(compilationTarget)) {
logger.debug(undefined, `${compilationTarget.id} already exists in compilation set`)
return
Expand Down
4 changes: 2 additions & 2 deletions src/awst_build/contract-visitor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import { BoxProxyExpressionBuilder } from './eb/storage/box'
import { GlobalStateFunctionResultBuilder } from './eb/storage/global-state'
import { LocalStateFunctionResultBuilder } from './eb/storage/local-state'
import { requireInstanceBuilder } from './eb/util'
import { ContractClass } from './models/contract-class'
import { Index } from './models'
import type { ContractClassPType } from './ptypes'

export class ContractVisitor extends BaseVisitor implements Visitor<ClassElements, void> {
Expand Down Expand Up @@ -57,7 +57,7 @@ export class ContractVisitor extends BaseVisitor implements Visitor<ClassElement
logger.error(this._approvalProgram.sourceLocation, 'ARC4 contracts cannot define their own approval methods.')
}

const contract = new ContractClass({
const contract = new Index({
type: this._contractPType,
propertyInitialization: this._propertyInitialization,
isAbstract: isAbstract,
Expand Down
30 changes: 22 additions & 8 deletions src/awst_build/eb/arc4/arrays.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { nodeFactory } from '../../../awst/node-factory'
import type { Expression } from '../../../awst/nodes'
import { StringConstant } from '../../../awst/nodes'
import { IntegerConstant, StringConstant } from '../../../awst/nodes'
import type { SourceLocation } from '../../../awst/source-location'
import { wtypes } from '../../../awst/wtypes'
import { Constants } from '../../../constants'
Expand Down Expand Up @@ -166,13 +166,24 @@ export class AddressConstructorBuilder extends NodeBuilder {
export abstract class ArrayExpressionBuilder<
TArrayType extends DynamicArrayType | StaticArrayType,
> extends Arc4EncodedBaseExpressionBuilder<TArrayType> {
iterate(sourceLocation: SourceLocation): Expression {
iterate(): Expression {
return this.resolve()
}

indexAccess(index: InstanceBuilder, sourceLocation: SourceLocation): NodeBuilder {
// TODO
return super.indexAccess(index, sourceLocation)
const indexExpr = requireExpressionOfType(index, uint64PType)
if (indexExpr instanceof IntegerConstant && this.ptype instanceof StaticArrayType && indexExpr.value >= this.ptype.arraySize) {
logger.error(index.sourceLocation, 'Index access out of bounds')
}
return instanceEb(
nodeFactory.indexExpression({
base: this.resolve(),
sourceLocation: sourceLocation,
index: indexExpr,
wtype: this.ptype.elementType.wtype,
}),
this.ptype.elementType,
)
}

memberAccess(name: string, sourceLocation: SourceLocation): NodeBuilder {
Expand All @@ -189,8 +200,11 @@ export abstract class ArrayExpressionBuilder<
return new EntriesFunctionBuilder(this)
case 'copy':
return new CopyFunctionBuilder(this)
case 'slice':
return new SliceFunctionBuilder(this.resolve(), this.ptype)
case 'slice': {
const sliceResult =
this.ptype instanceof StaticArrayType ? new DynamicArrayType({ elementType: this.ptype.elementType }) : this.ptype
return new SliceFunctionBuilder(this.resolve(), sliceResult)
}
}
return super.memberAccess(name, sourceLocation)
}
Expand Down Expand Up @@ -222,7 +236,7 @@ class EntriesFunctionBuilder extends FunctionBuilder {
const iteratorType = IterableIteratorType.parameterise([new TuplePType({ items: [uint64PType, this.arrayBuilder.ptype.elementType] })])
return new IterableIteratorExpressionBuilder(
nodeFactory.enumeration({
expr: this.arrayBuilder.iterate(sourceLocation),
expr: this.arrayBuilder.iterate(),
sourceLocation,
wtype: iteratorType.wtype,
}),
Expand Down Expand Up @@ -301,7 +315,7 @@ export class ArrayPushFunctionBuilder extends FunctionBuilder {
} = parseFunctionArgs({
args,
typeArgs,
funcName: 'at',
funcName: 'push',
callLocation: sourceLocation,
genericTypeArgs: 0,
argSpec: (a) => [a.required(elementType), ...args.slice(1).map(() => a.required(elementType))],
Expand Down
14 changes: 8 additions & 6 deletions src/awst_build/eb/arc4/uint-n-constructor-builder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,10 @@ export class ByteConstructorBuilder extends NodeBuilder {
function newUintN(initialValueBuilder: InstanceBuilder | undefined, ptype: UintNType, sourceLocation: SourceLocation) {
if (initialValueBuilder === undefined) {
return new UintNExpressionBuilder(
nodeFactory.bytesConstant({
value: new Uint8Array([0]),
wtype: ptype.wtypeOrThrow,
nodeFactory.integerConstant({
value: 0n,
tealAlias: null,
wtype: ptype.wtype,
sourceLocation: sourceLocation,
}),
ptype,
Expand All @@ -80,9 +81,10 @@ function newUintN(initialValueBuilder: InstanceBuilder | undefined, ptype: UintN
if (initialValue instanceof IntegerConstant) {
codeInvariant(isValidLiteralForPType(initialValue.value, ptype), `${initialValue.value} cannot be converted to ${ptype}`)
return new UintNExpressionBuilder(
nodeFactory.bytesConstant({
value: bigIntToUint8Array(initialValue.value),
wtype: ptype.wtypeOrThrow,
nodeFactory.integerConstant({
value: initialValue.value,
wtype: ptype.wtype,
tealAlias: null,
sourceLocation: sourceLocation,
}),
ptype,
Expand Down
2 changes: 1 addition & 1 deletion src/awst_build/eb/shared/at-function-builder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ export class AtFunctionBuilder extends FunctionBuilder {
if (typeof this.exprLength === 'bigint') {
let indexValue = indexParam < 0 ? this.exprLength + indexParam : indexParam
if (indexValue < 0n || indexValue >= this.exprLength) {
logger.warn(index.sourceLocation, 'Index access out of bounds')
logger.error(index.sourceLocation, 'Index access out of bounds')
indexValue = 0n
}
indexExpr = nodeFactory.uInt64Constant({
Expand Down
2 changes: 1 addition & 1 deletion src/awst_build/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import { jsonSerializeAwst } from '../awst/json-serialize-awst'
import type { CompilationSet } from '../awst/models'
import type { AWST } from '../awst/nodes'
import { SourceLocation } from '../awst/source-location'
import { ToCodeVisitor } from '../awst/to-code-visitor'
Expand All @@ -10,6 +9,7 @@ import type { CreateProgramResult } from '../parser'
import { ArtifactKind, writeArtifact } from '../write-artifact'
import { buildContextForProgram } from './context/awst-build-context'
import { buildLibAwst } from './lib'
import type { CompilationSet } from './models'
import { SourceFileVisitor } from './source-file-visitor'

export function buildAwst({ program, sourceFiles }: CreateProgramResult, options: CompileOptions): [AWST[], CompilationSet] {
Expand Down
Loading

0 comments on commit 818f1aa

Please sign in to comment.