Skip to content
This repository has been archived by the owner on Sep 17, 2022. It is now read-only.

Start integration with TensorBoard #197

Merged
merged 32 commits into from
Feb 5, 2019
Merged

Conversation

caisq
Copy link
Contributor

@caisq caisq commented Jan 29, 2019

  • Add node.js-specific representation of int64-type tensors
  • Add node.js-specific support for resource-type tensors
  • Add the following node.js backend-specific op kernel binding:
    • summaryWriter
    • createSummaryFileWriter
    • writeScalarSummary
    • flushSummaryWriter
  • Add the public Node.js JavaScript/TypeScrpit api in tensorboard.ts, under the API namespace tf.node.*

Usage example (TypeScript):

import * as tf from '@tensorflow/tfjs-node';

const summaryWriter = tf.node.summaryFileWriter('/tmp/tfjs_tb_logdir');
for (let i = -1e3; i < 1e3; i += 10) {
  summaryWriter.scalar('loss', i * i * i * i, i);
  summaryWriter.scalar('acc', -i * i * i * i, i);
}

Open tensorboard:

pip install tensorboard  # Unless you've already installed it.
tensorboard --logdir /tmp/tfjs_tb_logdir

Screenshot of result:
image

Towards tensorflow/tfjs#686

This change is Reviewable

@caisq caisq changed the title Start integration with TensorBoard [WIP; DO NOT REVIEW YET] Start integration with TensorBoard Jan 29, 2019
@caisq caisq changed the title [WIP; DO NOT REVIEW YET] Start integration with TensorBoard Start integration with TensorBoard Jan 30, 2019
@caisq
Copy link
Contributor Author

caisq commented Jan 30, 2019

cc @nfelt @stephanwlee

@caisq
Copy link
Contributor Author

caisq commented Jan 30, 2019

I will add unit tests once we agree on the general approach.

Copy link
Contributor

@nkreeger nkreeger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewable status: 0 of 1 approvals obtained (waiting on @caisq, @dsmilkov, @nkreeger, and @nsthorat)


binding/tfjs_backend.cc, line 381 at r2 (raw file):

const char *limit = static_cast<const char *>(tensor_data) + byte_length;

This line looks unused (and wrong). TF_TensorByteSize should return the actual byte size of what TF_TensorData returns.


binding/tfjs_backend.cc, line 393 at r2 (raw file):

const char *data = static_cast<const char *>(tensor_data);

Just do this all at once:

const char *data = static_cast<const char *>(TF_TensorData(tensor.tensor));

binding/tfjs_backend.cc, line 399 at r2 (raw file):

nstatus = napi_create_array_with_length(env, byte_length, result);

ENSURE_NAPI_OK(env, nstatus, nstatus);


binding/tfjs_backend.cc, line 402 at r2 (raw file):

 *array_buffer_data;

init to nullptr:

void *array_buffer_data = nullptr;

binding/tfjs_backend.cc, line 411 at r2 (raw file):

napi_uint8_array

We should call out in docs that this method will only return uint8 arrays.

Copy link
Contributor Author

@caisq caisq left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewable status: 0 of 1 approvals obtained (waiting on @caisq, @dsmilkov, @nkreeger, and @nsthorat)


binding/tfjs_backend.cc, line 381 at r2 (raw file):

Previously, nkreeger (Nick Kreeger) wrote…
const char *limit = static_cast<const char *>(tensor_data) + byte_length;

This line looks unused (and wrong). TF_TensorByteSize should return the actual byte size of what TF_TensorData returns.

Removed this line.


binding/tfjs_backend.cc, line 393 at r2 (raw file):

Previously, nkreeger (Nick Kreeger) wrote…
const char *data = static_cast<const char *>(tensor_data);

Just do this all at once:

const char *data = static_cast<const char *>(TF_TensorData(tensor.tensor));

Done.


binding/tfjs_backend.cc, line 399 at r2 (raw file):

Previously, nkreeger (Nick Kreeger) wrote…
nstatus = napi_create_array_with_length(env, byte_length, result);

ENSURE_NAPI_OK(env, nstatus, nstatus);

Done.


binding/tfjs_backend.cc, line 402 at r2 (raw file):

Previously, nkreeger (Nick Kreeger) wrote…
 *array_buffer_data;

init to nullptr:

void *array_buffer_data = nullptr;

Done.

Copy link
Contributor Author

@caisq caisq left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewable status: 0 of 1 approvals obtained (waiting on @dsmilkov, @nkreeger, and @nsthorat)


binding/tfjs_backend.cc, line 411 at r2 (raw file):

Previously, nkreeger (Nick Kreeger) wrote…
napi_uint8_array

We should call out in docs that this method will only return uint8 arrays.

Done.

Copy link
Contributor Author

@caisq caisq left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nkreeger Thank you for the review so far! I've addressed all your comments and added unit tests.

PTAL.

Reviewable status: 0 of 1 approvals obtained (waiting on @dsmilkov, @nkreeger, and @nsthorat)

Copy link
Contributor

@dsmilkov dsmilkov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewed 2 of 10 files at r1, 3 of 8 files at r3.
Reviewable status: 0 of 1 approvals obtained (waiting on @caisq, @dsmilkov, @nkreeger, and @nsthorat)


binding/tfjs_backend.cc, line 87 at r3 (raw file):

    // Currently, int64-type Tensors are represented as Int32Arrays. So the
    // logic for comparing the byte size of the typed-array representation and
    // the byte size of the tensor dtype needs to be special-cased for int64.

since you asserted that the value is in the int32 range, no need for 2 numbers to represent a single number, which will simplify the logic here and below


src/int64_tensors.ts, line 62 at r3 (raw file):

    // We use two int32 elements to represent a int64 value. This assumes
    // little endian, which is checked above.
    const highPart = Math.floor(value / INT32_MAX);

since you asserted that the value is in the int32 range, why use two numbers to represent. Isn't a single number sufficient?


src/node.ts, line 24 at r3 (raw file):

import {summaryFileWriter} from './tensorboard';

export const node = {summaryFileWriter};

Why create an artificial namespace "node. The user will already import tfjs-node which means effectively that user lives in the node namespace. There is no need for namespaces here since we don't target the browser, so things are not sticked to a global tf object. cc @nsthorat WDYT?


src/tensorboard.ts, line 75 at r3 (raw file):

 * const tf = require('@tensorflow/tfjs-node');
 *
 * const summaryWriter = tf.node.summaryFileWriter('/tmp/tfjs_tb_logdir');

This should be tf.summaryFileWriter (no tf.node.* since the user is already importing tfjs-node which is effectively tf.node. No need for namespaces here since we don't target the browser, thus things are not sticked to a global tf object.


src/tensorboard.ts, line 95 at r3 (raw file):

  util.assert(
      logdir != null && typeof logdir === 'string' && logdir.length > 0,
      `logdir is null, undefined, not a string, or an empty string`);

Change the user message to Invalid logdir "${logdir}". Please provide the logging directory as a string.


src/tensorboard_test.ts, line 27 at r3 (raw file):

const tmp = require('tmp');

import {summaryFileWriter} from './tensorboard';

instead of this, do import * as tfn from './index' and use the public API to unit test. This helps us understand how namespaces interact.

Copy link
Contributor

@dsmilkov dsmilkov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewable status: 0 of 1 approvals obtained (waiting on @caisq, @nkreeger, and @nsthorat)


src/node.ts, line 24 at r3 (raw file):

Previously, dsmilkov (Daniel Smilkov) wrote…

Why create an artificial namespace "node. The user will already import tfjs-node which means effectively that user lives in the node namespace. There is no need for namespaces here since we don't target the browser, so things are not sticked to a global tf object. cc @nsthorat WDYT?

I'm torn because in some way it makes it super clear why methods will work only in node. But in another way more users are using import {methodINeed} from 'module' instead of import *. For former is also recommended by all bundlers in order to help tree-shake the codebase. That means that with an artificial 'node' namespace, users can't import only summaryWriter since you can't do import {node.summaryWriter} from 'tfjs-node'

Copy link
Contributor

@dsmilkov dsmilkov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewable status: 0 of 1 approvals obtained (waiting on @caisq, @nkreeger, and @nsthorat)


src/node.ts, line 24 at r3 (raw file):

Previously, dsmilkov (Daniel Smilkov) wrote…

I'm torn because in some way it makes it super clear why methods will work only in node. But in another way more users are using import {methodINeed} from 'module' instead of import *. For former is also recommended by all bundlers in order to help tree-shake the codebase. That means that with an artificial 'node' namespace, users can't import only summaryWriter since you can't do import {node.summaryWriter} from 'tfjs-node'

Nevermind. Nikhil had a great point. Tree-shaking doesn't matter in node so introducing a node namespace sounds good.


src/tensorboard.ts, line 75 at r3 (raw file):

Previously, dsmilkov (Daniel Smilkov) wrote…

This should be tf.summaryFileWriter (no tf.node.* since the user is already importing tfjs-node which is effectively tf.node. No need for namespaces here since we don't target the browser, thus things are not sticked to a global tf object.

Ignore this comment. See above reply

Copy link
Contributor Author

@caisq caisq left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewable status: 0 of 1 approvals obtained (waiting on @dsmilkov, @nkreeger, and @nsthorat)


binding/tfjs_backend.cc, line 87 at r3 (raw file):

Previously, dsmilkov (Daniel Smilkov) wrote…

since you asserted that the value is in the int32 range, no need for 2 numbers to represent a single number, which will simplify the logic here and below

It is necessary to represent the int64 as two numbers for the following two reasons

  1. When doing memcpy, the DT_INT64-dtype tensor in C++ expects eight bytes. Storing it as two int32 values is convenient for the memcpy'ing.
  2. The sign is in the high part (i.e., the part that's not the least significant digits corresponding to the int32 value). So in order to represent negative int32 values correctly, we need to use the next four bytes.

src/int64_tensors.ts, line 62 at r3 (raw file):

Previously, dsmilkov (Daniel Smilkov) wrote…

since you asserted that the value is in the int32 range, why use two numbers to represent. Isn't a single number sufficient?

See my comment above. Doing it like this works. But I agree that it is confusing. This line really takes care of only the sign bit. I've revised it to avoid potential confusion.


src/node.ts, line 24 at r3 (raw file):

Previously, dsmilkov (Daniel Smilkov) wrote…

Nevermind. Nikhil had a great point. Tree-shaking doesn't matter in node so introducing a node namespace sounds good.

Acknowledged. Thanks for the discussion.


src/tensorboard.ts, line 75 at r3 (raw file):

Previously, dsmilkov (Daniel Smilkov) wrote…

Ignore this comment. See above reply

Ack.


src/tensorboard.ts, line 95 at r3 (raw file):

Previously, dsmilkov (Daniel Smilkov) wrote…

Change the user message to Invalid logdir "${logdir}". Please provide the logging directory as a string.

Done.


src/tensorboard_test.ts, line 27 at r3 (raw file):

Previously, dsmilkov (Daniel Smilkov) wrote…

instead of this, do import * as tfn from './index' and use the public API to unit test. This helps us understand how namespaces interact.

Done.

Copy link
Contributor

@dsmilkov dsmilkov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fast turnaround. LGTM!!

Reviewed 3 of 10 files at r1, 3 of 8 files at r3, 3 of 3 files at r4.
Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @caisq, @nkreeger, and @nsthorat)


binding/tfjs_backend.cc, line 87 at r3 (raw file):

Previously, caisq (Shanqing Cai) wrote…

It is necessary to represent the int64 as two numbers for the following two reasons

  1. When doing memcpy, the DT_INT64-dtype tensor in C++ expects eight bytes. Storing it as two int32 values is convenient for the memcpy'ing.
  2. The sign is in the high part (i.e., the part that's not the least significant digits corresponding to the int32 value). So in order to represent negative int32 values correctly, we need to use the next four bytes.

SGTM since it makes memcpy convenient. Thanks for explaining.


src/int64_tensors.ts, line 62 at r3 (raw file):

Previously, caisq (Shanqing Cai) wrote…

See my comment above. Doing it like this works. But I agree that it is confusing. This line really takes care of only the sign bit. I've revised it to avoid potential confusion.

SGTM.

this.tensorMap.set((tensors[i] as Tensor).dataId, info);
}
ids.push(info.id);
} else {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should check

} else if (tensors[i] instanceof Int64Scalar) {
...
} else {
  throw new Error(`Invalid Tensor type: ${typeof tensors[i]}`);
}

*
* This class is introduced as a workaround.
*/
export class Int64Scalar {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little confused with this class still.

It represents an int64 but JS can't handle those types today. It tries to parse the high/low parts of the number and validate endianness? This is mostly confusing because I think we just wanted to check that the numbers are within INT32_MIN and INT32_MAX. Why do we need all the various other stuff - seems somewhat expensive for not doing anything.

Copy link
Contributor Author

@caisq caisq left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @caisq, @nkreeger, and @nsthorat)


src/int64_tensors.ts, line 35 at r4 (raw file):

Previously, nkreeger (Nick Kreeger) wrote…

I'm a little confused with this class still.

It represents an int64 but JS can't handle those types today. It tries to parse the high/low parts of the number and validate endianness? This is mostly confusing because I think we just wanted to check that the numbers are within INT32_MIN and INT32_MAX. Why do we need all the various other stuff - seems somewhat expensive for not doing anything.

The reason why we need to all the other stuff is because we need to worry about negative values. Even though the binary representation of a positive int32 number is exactly the low part of a positive int64 number of the same value, this is not true if we are talking about negative values. The negative sign is represented in the high part (i.e., the high 4 bytes). We can't just take the binary representation of the negative int32 value and pad it with zeros in the high part because that'll represent a different value in the int64 format.

The logic here takes care of the sign in a way that assumes little endian. This is why we check endianness at the beginning. Most machines support little endian nowadays, if we focus on the kind of devices that people will use tfjs-node on.

Let me know if that makes sense.

Copy link
Contributor

@nkreeger nkreeger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @caisq, @nkreeger, and @nsthorat)


src/int64_tensors.ts, line 35 at r4 (raw file):

Previously, caisq (Shanqing Cai) wrote…

The reason why we need to all the other stuff is because we need to worry about negative values. Even though the binary representation of a positive int32 number is exactly the low part of a positive int64 number of the same value, this is not true if we are talking about negative values. The negative sign is represented in the high part (i.e., the high 4 bytes). We can't just take the binary representation of the negative int32 value and pad it with zeros in the high part because that'll represent a different value in the int64 format.

The logic here takes care of the sign in a way that assumes little endian. This is why we check endianness at the beginning. Most machines support little endian nowadays, if we focus on the kind of devices that people will use tfjs-node on.

Let me know if that makes sense.

OK yes this makes sense. Can you add some of this in-line as documentation inside of the constructor? It is not clear why you need to break those out unless you have domain knowledge.

Copy link
Contributor Author

@caisq caisq left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @caisq, @nkreeger, and @nsthorat)


src/int64_tensors.ts, line 35 at r4 (raw file):

Previously, nkreeger (Nick Kreeger) wrote…

OK yes this makes sense. Can you add some of this in-line as documentation inside of the constructor? It is not clear why you need to break those out unless you have domain knowledge.

Done.

Copy link
Contributor Author

@caisq caisq left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @caisq, @nkreeger, and @nsthorat)


src/nodejs_kernel_backend.ts, line 126 at r4 (raw file):

else if
Done.

Copy link
Contributor

@nkreeger nkreeger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewable status: :shipit: complete! 2 of 1 approvals obtained (waiting on @caisq, @nkreeger, and @nsthorat)

@caisq caisq merged commit bf99863 into tensorflow:master Feb 5, 2019
@caisq caisq deleted the tensorboard branch February 5, 2019 20:41
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants