forked from ml-explore/mlx-swift-examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathArguments.swift
100 lines (79 loc) · 2.51 KB
/
Arguments.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
// Copyright © 2024 Apple Inc.
import ArgumentParser
import Foundation
import MLX
#if swift(>=5.10)
/// Lets `URL` values be parsed directly from command line arguments.
///
/// An argument containing `"://"` is handed to `URL(string:)`; any other
/// argument is handed to `URL(filePath:)`.
extension URL: @retroactive ExpressibleByArgument {
    /// Creates a URL from the raw command line text.
    ///
    /// - Parameter argument: The text supplied on the command line.
    public init?(argument: String) {
        // No scheme separator: hand the text to the file-path initializer.
        guard argument.contains("://") else {
            self.init(filePath: argument)
            return
        }
        // `URL(string:)` is failable, so its `nil` propagates to the caller.
        self.init(string: argument)
    }
}
#else
/// Lets `URL` values be parsed directly from command line arguments.
///
/// An argument containing `"://"` is handed to `URL(string:)`; any other
/// argument is handed to `URL(filePath:)`.
extension URL: ExpressibleByArgument {
    /// Creates a URL from the raw command line text.
    ///
    /// - Parameter argument: The text supplied on the command line.
    public init?(argument: String) {
        // No scheme separator: hand the text to the file-path initializer.
        guard argument.contains("://") else {
            self.init(filePath: argument)
            return
        }
        // `URL(string:)` is failable, so its `nil` propagates to the caller.
        self.init(string: argument)
    }
}
#endif
/// Command line options for configuring GPU cache/memory limits and for
/// printing memory snapshots.
struct MemoryArguments: ParsableArguments, Sendable {

    /// When set, the `report…` methods print; otherwise they do nothing.
    @Flag(name: .long, help: "Show memory stats")
    var memoryStats = false

    /// Cache limit; multiplied by 1024 * 1024 before being passed to `GPU.set(cacheLimit:)`.
    @Option(name: .long, help: "Maximum cache size in M")
    var cacheSize = 1024

    /// Optional memory limit; multiplied by 1024 * 1024 before being passed to
    /// `GPU.set(memoryLimit:)`. Left untouched when not supplied.
    @Option(name: .long, help: "Maximum memory size in M")
    var memorySize: Int?

    /// Snapshot recorded by `start`; `nil` until one of the `start` overloads has run.
    var startMemory: GPU.Snapshot?

    /// Pushes the configured cache limit, and the memory limit when present, to `GPU`.
    private func applyLimits() {
        GPU.set(cacheLimit: cacheSize * 1024 * 1024)
        if let limit = memorySize {
            GPU.set(memoryLimit: limit * 1024 * 1024)
        }
    }

    /// Applies the limits, runs `load`, then records the starting snapshot.
    ///
    /// The snapshot is taken only after `load` returns; if `load` throws,
    /// `startMemory` is left unchanged.
    ///
    /// - Parameter load: Work to perform before the starting snapshot is taken.
    /// - Returns: Whatever `load` returns.
    /// - Throws: Rethrows any error thrown by `load`.
    mutating func start<L>(_ load: () async throws -> L) async throws -> L {
        applyLimits()
        let loaded = try await load()
        startMemory = GPU.snapshot()
        return loaded
    }

    /// Applies the limits and records the starting snapshot immediately.
    mutating func start() {
        applyLimits()
        startMemory = GPU.snapshot()
    }

    /// Prints the description of a fresh GPU snapshot when `memoryStats` is set.
    func reportCurrent() {
        guard memoryStats else { return }
        print(GPU.snapshot().description)
    }

    /// Prints the current limits, the starting and ending snapshots, and their delta.
    ///
    /// Does nothing unless `memoryStats` is set and a starting snapshot was recorded.
    func reportMemoryStatistics() {
        guard memoryStats, let baseline = startMemory else { return }

        // Capture the ending snapshot before anything is printed.
        let current = GPU.snapshot()
        let rule = "======="

        print(rule)
        print("Memory size: \(GPU.memoryLimit / 1024)K")
        print("Cache size: \(GPU.cacheLimit / 1024)K")
        print("")

        print(rule)
        print("Starting memory")
        print(baseline.description)
        print("")

        print(rule)
        print("Ending memory")
        print(current.description)
        print("")

        print(rule)
        print("Growth")
        let growth = baseline.delta(current)
        print(growth.description)
    }
}