How to get specific rows of a tensor in TensorFlow?

by coty_beier, in category: General Help, 3 months ago



1 answer

by emerald.wunsch, 3 months ago

@coty_beier 

To get specific rows of a tensor in TensorFlow, you can use tf.gather, which collects the slices of a tensor at the given indices along an axis (axis 0 by default, which for a 2-D tensor means rows). Here's an example:

import tensorflow as tf

# Create a tensor
tensor = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# Get specific rows
rows = tf.constant([0, 2])  # Rows 0 and 2
selected_rows = tf.gather(tensor, rows)

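# TensorFlow 2.x executes eagerly, so no session is needed
# (tf.Session was removed in 2.x; in TF 1.x you would evaluate
# the result with sess.run(selected_rows) inside a tf.Session)
print(selected_rows.numpy())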


Output:

[[1 2 3]
 [7 8 9]]


In this example, tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) creates a 3x3 tensor, tf.constant([0, 2]) holds the indices of the rows to select, and tf.gather(tensor, rows) returns a new 2x3 tensor containing those rows in the order the indices are given.
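The indices passed to tf.gather do not have to be sorted or unique, and a plain Python list works as well as a tensor. The sketch below assumes TensorFlow 2.x (eager execution) and shows that the output rows follow the order of the indices, repeats included.

```python
import tensorflow as tf

tensor = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# Indices may be a plain Python list, in any order, with repeats
reordered = tf.gather(tensor, [2, 0, 2])

print(reordered.numpy())
# [[7 8 9]
#  [1 2 3]
#  [7 8 9]]
```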


Note that row indices in TensorFlow start at 0. The rows tensor can list any row indices, so non-consecutive rows (as in this example), rows in a different order, or repeated rows can all be selected with a single tf.gather call.
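When the rows you need form a contiguous range, ordinary slicing on the tensor is enough, and when rows should be chosen by a condition, tf.boolean_mask keeps the rows where a 1-D boolean mask is True. A minimal sketch, again assuming TensorFlow 2.x; the condition used here (first column greater than 2) is only an illustration.

```python
import tensorflow as tf

tensor = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# Contiguous rows: standard slicing (rows 0 and 1)
first_two = tensor[0:2]

# Rows picked by a condition: keep rows whose first element is greater than 2
mask = tensor[:, 0] > 2                    # boolean tensor [False, True, True]
filtered = tf.boolean_mask(tensor, mask)

# Equivalent index-based form: tf.where gives the indices of the True entries
# as an [N, 1] int64 tensor, which is flattened before passing it to tf.gather
row_ids = tf.reshape(tf.where(mask), [-1])
filtered_via_gather = tf.gather(tensor, row_ids)

print(first_two.numpy())
print(filtered.numpy())
```

Slicing only covers evenly spaced ranges, so tf.gather remains the general tool for an arbitrary list of row indices.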