Merge pull request #24710 from rajasekharporeddy:typos
PiperOrigin-RevId: 693412112
Google-ML-Automation committed Nov 5, 2024
2 parents 497a5a3 + a80d027 commit c1af808
Showing 3 changed files with 13 additions and 13 deletions.
4 changes: 2 additions & 2 deletions docs/export/export.md
@@ -247,7 +247,7 @@ for which the code was exported.
You can specify explicitly for what platforms the code should be exported.
This allows you to specify a different accelerator than you have
available at export time,
-and it even allows you to specify multi-platform lexport to
+and it even allows you to specify multi-platform export to
obtain an `Exported` object that can be compiled and executed
on multiple platforms.

@@ -293,7 +293,7 @@ resulting module size should be only marginally larger than the
size of a module with default export.
As an extreme case, when serializing a module without any
primitives with platform-specific lowering, you will get
-the same StableHLO as for the single-plaform export.
+the same StableHLO as for the single-platform export.

```python
>>> import jax
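>>> from jax import export

>>> # The rest of the original example is collapsed in this diff view. What
>>> # follows is a hedged sketch of multi-platform export, not the file's
>>> # original code (platform names are illustrative; assumes a JAX version
>>> # where `jax.export.export` accepts a `platforms` argument):
>>> def f(x): return jax.numpy.sin(x)
>>> exp = export.export(jax.jit(f), platforms=("cpu", "tpu"))(
...     jax.ShapeDtypeStruct((4,), jax.numpy.float32))
>>> exp.platforms
('cpu', 'tpu')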
```
12 changes: 6 additions & 6 deletions docs/export/shape_poly.md
@@ -44,7 +44,7 @@ following example:
```

Note that such functions are still re-compiled on demand for
-each concrete input shapes they are invoked on. Only the
+each concrete input shape they are invoked on. Only the
tracing and the lowering are saved.
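
For instance, a minimal sketch of a symbolic-shape export (not part of this diff; the function and shapes are illustrative, assuming the `jax.export` API):

```python
import jax
import jax.numpy as jnp
from jax import export

def f(x):  # x: f32[b, 4]
    return jnp.sum(x * x, axis=-1)

# One export covers every batch size b; tracing and lowering happen once.
spec = jax.ShapeDtypeStruct(export.symbolic_shape("b, 4"), jnp.float32)
exp = export.export(jax.jit(f))(spec)
print(exp.in_avals)  # (ShapedArray(float32[b,4]),)
```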

The {func}`jax.export.symbolic_shape` is used in the above
@@ -98,7 +98,7 @@ A few examples of shape specifications:
arguments. Note that the same specification would work if the first
argument is a pytree of 3D arrays, all with the same leading dimension
but possibly with different trailing dimensions.
-The value `None` for the second arugment means that the argument
+The value `None` for the second argument means that the argument
is not symbolic. Equivalently, one can use `...`.

* `("(batch, ...)", "(batch,)")` specifies that the two arguments
@@ -256,7 +256,7 @@ as follows:
integers. E.g., `b >= 1`, `b >= 0`, `2 * a + b >= 3` are `True`, while `b >= 2`,
`a >= b`, `a - b >= 0` are inconclusive and result in an exception.

-In cases where a comparison operation cannot be resolve to a boolean,
+In cases where a comparison operation cannot be resolved to a boolean,
we raise {class}`InconclusiveDimensionOperation`. E.g.,

```python
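# The original example is collapsed in this diff view. Below is a hedged
# sketch of an inconclusive comparison (assumes the jax.export API):
>>> import jax
>>> from jax import export
>>> a, b = export.symbolic_shape("a, b")
>>> a >= 1   # decidable: dimension variables are at least 1
True
>>> a >= b   # neither outcome is implied; raises InconclusiveDimensionOperation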
```

@@ -351,7 +351,7 @@ symbolic constraints:
is encountered, it is rewritten to the expression on
the right.
E.g., `floordiv(a, b) == c` works by replacing all
-occurences of `floordiv(a, b)` with `c`.
+occurrences of `floordiv(a, b)` with `c`.
Equality constraints must not contain addition or
subtraction at the top-level on the left-hand-side. Examples of
valid left-hand-sides are `a * b`, or `4 * a`, or
@@ -498,11 +498,11 @@ This works well for most use cases, and
it mirrors the calling convention of JIT functions.

Sometimes you may want to export a function parameterized
-by an integer values that determines some shapes in the program.
+by an integer value that determines some shapes in the program.
For example, we may
want to export the function `my_top_k` defined below,
parameterized by the
-value of `k`, which determined the shape of the result.
+value of `k`, which determines the shape of the result.
The following attempt will lead to an error since the dimension
variable `k` cannot be derived from the shape of the input `x: i32[4, 10]`:

10 changes: 5 additions & 5 deletions jax/experimental/jax2tf/README.md
@@ -237,7 +237,7 @@ params_vars = tf.nest.map_structure(tf.Variable, params)
prediction_tf = lambda inputs: jax2tf.convert(model_jax)(params_vars, inputs)

my_model = tf.Module()
-# Tell the model saver what are the variables.
+# Tell the model saver what the variables are.
my_model._variables = tf.nest.flatten(params_vars)
my_model.f = tf.function(prediction_tf, jit_compile=True, autograph=False)
tf.saved_model.save(my_model)
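# (Not part of this diff: a hedged usage sketch; the path is illustrative.
#  tf.saved_model.save also takes a target directory, e.g.:)
# tf.saved_model.save(my_model, "/tmp/my_model")
# restored = tf.saved_model.load("/tmp/my_model")
# restored.f(inputs)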
@@ -760,7 +760,7 @@ symbolic constraints:
We plan to improve somewhat this area in the future.
* Equality constraints are treated as normalization rules.
E.g., `floordiv(a, b) = c` works by replacing all
-occurences of the left-hand-side with the right-hand-side.
+occurrences of the left-hand-side with the right-hand-side.
You can only have equality constraints where the left-hand-side
is a multiplication of factors, e.g, `a * b`, or `4 * a`, or
`floordiv(a, b)`. Thus, the left-hand-side cannot contain
@@ -1048,7 +1048,7 @@ jax2tf.convert(jnp.sin)(np.float64(3.14)) # Has type float32
tf.function(jax2tf.convert(jnp.sin), autograph=False)(tf.Variable(3.14, dtype=tf.float64))
```

-When the `JAX_ENABLE_X64` flas is set, JAX uses 64-bit types
+When the `JAX_ENABLE_X64` flag is set, JAX uses 64-bit types
for Python scalars and respects the explicit 64-bit types:

```python
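# The original example is collapsed in this diff view. Below is a hedged
# sketch of the behavior described above, assuming JAX_ENABLE_X64=1:
import numpy as np
import jax.numpy as jnp
from jax.experimental import jax2tf

jax2tf.convert(jnp.sin)(np.float64(3.14))  # now keeps type float64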
```

@@ -1245,7 +1245,7 @@ Applies to both native and non-native serialization.
trackable classes during attribute assignment.
Python Dict/List/Tuple are changed to _DictWrapper/_ListWrapper/_TupleWrapper
classes.
-In most situation, these Wrapper classes work exactly as the standard
+In most situations, these Wrapper classes work exactly as the standard
Python data types. However, the low-level pytree data structures are different
and this can lead to errors.
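
A hedged illustration of this wrapping (the class name is TensorFlow-internal):

```python
import tensorflow as tf

m = tf.Module()
m.params = {"w": tf.Variable(1.0)}
# The assigned dict is silently wrapped for checkpoint tracking:
print(type(m.params).__name__)  # _DictWrapper, not dict
```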

@@ -1499,7 +1499,7 @@ during lowering we try to generate one TensorFlow op for one JAX primitive.
We expect that the lowering that XLA does is similar to that done by JAX
before conversion. (This is a hypothesis, we have not yet verified it extensively.)

-There is one know case when the performance of the lowered code will be different.
+There is one known case when the performance of the lowered code will be different.
JAX programs use a [stateless
deterministic PRNG](https://github.com/jax-ml/jax/blob/main/docs/design_notes/prng.md)
and it has an internal JAX primitive for it.
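
For reference, a minimal sketch of that stateless PRNG style (standard `jax.random` API):

```python
import jax

key = jax.random.PRNGKey(0)      # explicit, immutable PRNG state
k1, k2 = jax.random.split(key)   # derive independent streams by splitting
x = jax.random.normal(k1, (2,))  # the same key always yields the same values
```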
