forked from swiftlang/swift
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path derivative_attr_parse.swift
100 lines (83 loc) · 3.01 KB
/
derivative_attr_parse.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
// RUN: %target-swift-frontend -parse -verify %s
/// Good
// Well-formed `@derivative` attributes; none of these should produce a
// diagnostic. The RUN line passes `-parse`, so the file is only parsed: names
// such as `sin`, `add`, and `foo` are never resolved, and the function bodies
// are placeholders that merely need to be syntactically valid.
//
// Case: a single differentiability parameter named after `wrt:`.
@derivative(of: sin, wrt: x) // ok
func vjpSin(x: Float) -> (value: Float, pullback: (Float) -> Float) {
return (x, { $0 })
}
// Case: several differentiability parameters, given as a parenthesized list
// after `wrt:`.
@derivative(of: add, wrt: (x, y)) // ok
func vjpAdd(x: Float, y: Float)
-> (value: Float, pullback: (Float) -> (Float, Float)) {
return (x + y, { ($0, $0) })
}
// Case: an operator (`+`) as the original function name, with no `wrt:`
// clause, on a static method inside a constrained protocol extension.
extension AdditiveArithmetic where Self: Differentiable {
@derivative(of: +) // ok
static func vjpAdd(x: Self, y: Self)
-> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
return (x + y, { v in (v, v) })
}
}
// Case: no `wrt:` clause, on a derivative that returns a `differential`
// instead of a `pullback`.
@derivative(of: foo) // ok
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}
/// Bad
// Malformed `@derivative` attributes. Each case pins its diagnostics with
// `@+N` directives, where N is the line offset from the directive to the
// attribute, so no lines may be inserted between a directive and its target.
// Every case reuses the name `dfoo`; redeclaration is a semantic check that
// `-parse` does not run, so it is not diagnosed.
//
// Case: the `of:` argument is an integer literal rather than a function name.
// Parser recovery adds two follow-on diagnostics after the primary one.
// expected-error @+3 {{expected an original function name}}
// expected-error @+2 {{expected ')' in 'derivative' attribute}}
// expected-error @+1 {{expected declaration}}
@derivative(of: 3)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}
// Case: `wrt` appears without a colon right after `of:`. Judging by the
// diagnostic, it is presumably parsed as the original function name, and the
// following `foo` is then rejected for not being the `wrt:` label.
// expected-error @+1 {{expected label 'wrt:' in '@derivative' attribute}}
@derivative(of: wrt, foo)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}
// Case: the `wrt` label is present but its colon is missing.
// expected-error @+1 {{expected a colon ':' after 'wrt'}}
@derivative(of: foo, wrt)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}
// Case: an unrecognized argument (`blah`) between the original function name
// and the `wrt:` clause.
// expected-error @+1 {{expected label 'wrt:' in '@derivative' attribute}}
@derivative(of: foo, blah, wrt: x)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}
// Case: an extra argument (`blah`) after a valid `wrt:` clause. The attribute
// must close after `wrt:`, so a missing-')' diagnostic is reported, followed
// by a follow-on diagnostic from recovery.
// expected-error @+2 {{expected ')' in 'derivative' attribute}}
// expected-error @+1 {{expected declaration}}
@derivative(of: foo, wrt: x, blah)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}
// Case: a trailing comma after the original function name, with no `wrt:`
// clause following it.
// expected-error @+1 {{unexpected ',' separator}}
@derivative(of: foo,)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}
// Case: a trailing comma after the `wrt:` clause. Unlike a trailing comma
// directly after the original name, this yields a missing-')' diagnostic plus
// a follow-on diagnostic from recovery.
// expected-error @+2 {{expected ')' in 'derivative' attribute}}
// expected-error @+1 {{expected declaration}}
@derivative(of: foo, wrt: x,)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}
// Case: a second bare identifier (`foo`) where the `wrt:` label should be,
// followed by a trailing comma.
// expected-error @+1 {{expected label 'wrt:' in '@derivative' attribute}}
@derivative(of: foo, foo,)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}
// Case: a trailing comma after the original function name.
// NOTE(review): this is identical to an earlier case in this file (the same
// `@derivative(of: foo,)` attribute with the same diagnostic), so it adds no
// coverage. Consider removing it or replacing it with a distinct case.
// expected-error @+1 {{unexpected ',' separator}}
@derivative(of: foo,)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}
func localDerivativeRegistration() {
// Case: `@derivative` on a function declared inside another function's body.
// The attribute is only accepted in a non-local scope. (The local `dsin` has
// no body; only the placement diagnostic is pinned for this case.)
// expected-error @+1 {{attribute '@derivative' can only be used in a non-local scope}}
@derivative(of: sin)
func dsin()
}
// Test deprecated `@differentiating` attribute.
// The older spelling takes the original function as a bare argument instead
// of `of:`. It still parses, but produces a deprecation warning that points to
// `@derivative(of:)`. Redeclaring `vjpSin` is not diagnosed because the file
// is only parsed (`-parse` in the RUN line).
// expected-warning @+1 {{'@differentiating' attribute is deprecated; use '@derivative(of:)' instead}}
@differentiating(sin)
func vjpSin(x: Float) -> (value: Float, pullback: (Float) -> Float) {
return (x, { $0 })
}