How to fetch specific rows from a tensor in TensorFlow?

by coty_beier, in category: General Help, a year ago

2 answers

by gaston, a year ago

@coty_beier 

To fetch a contiguous range of rows from a tensor in TensorFlow, you can use Python slicing syntax, tensor[start:end], where start is included and end is excluded. Here's an example:

import tensorflow as tf

# Create a tensor
tensor = tf.constant([[1, 2, 3],
                      [4, 5, 6],
                      [7, 8, 9]])

# Fetch specific rows
specific_rows = tensor[1:3]

# Print the result; TensorFlow 2.x executes eagerly, so no session is needed
print(specific_rows.numpy())


Output:

[[4 5 6]
 [7 8 9]]


In this example, specific_rows = tensor[1:3] fetches the rows at indices 1 and 2 (the second and third rows, since indexing starts at 0) from tensor. The resulting tensor specific_rows contains the selected rows.
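Slicing always selects a contiguous (optionally strided) range. As a minimal sketch, assuming TensorFlow 2.x with eager execution, the same rows can also be taken with tf.slice, which takes a start position and a size for each dimension, and a step value in the slice skips rows:

```python
import tensorflow as tf

tensor = tf.constant([[1, 2, 3],
                      [4, 5, 6],
                      [7, 8, 9]])

# Python slicing: rows at indices 1 and 2 (end index 3 is exclusive)
sliced = tensor[1:3]

# tf.slice: start at row 1, column 0; take 2 rows and all columns
# (a size of -1 means "everything from the start position to the end")
sliced_explicit = tf.slice(tensor, begin=[1, 0], size=[2, -1])

# A step in the slice: every other row, starting at row 0
every_other = tensor[::2]

print(sliced_explicit.numpy())
print(every_other.numpy())
```

Both sliced and sliced_explicit hold the rows [4 5 6] and [7 8 9]; every_other holds [1 2 3] and [7 8 9].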

by adolf, 8 months ago

@coty_beier 

To fetch specific rows from a tensor in TensorFlow, you can use the tf.gather() function, which gathers slices of a tensor along an axis (axis 0, i.e. rows, by default) according to the indices you provide.

Here's an example of how to fetch specific rows from a tensor using tf.gather():

import tensorflow as tf

# Create a tensor
tensor = tf.constant([[1, 2, 3],
                      [4, 5, 6],
                      [7, 8, 9]])

# Define the indices of the rows you want to fetch
indices = [1, 2]

# Fetch specific rows using tf.gather()
specific_rows = tf.gather(tensor, indices)

# Print the result; TensorFlow 2.x executes eagerly, so no session is needed
print(specific_rows.numpy())


Output:

[[4 5 6]
 [7 8 9]]


In this example, indices = [1, 2] specifies that we want to fetch rows at indices 1 and 2 from the tensor. The tf.gather() function is then used to retrieve these specific rows, resulting in a new tensor containing the selected rows.
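Unlike slicing, tf.gather() accepts an arbitrary list of indices. A minimal sketch, assuming TensorFlow 2.x: the indices can be out of order or repeated, the axis argument gathers columns instead of rows, and tf.boolean_mask (a different function, not part of the answer above) selects rows by a condition rather than by index:

```python
import tensorflow as tf

tensor = tf.constant([[1, 2, 3],
                      [4, 5, 6],
                      [7, 8, 9]])

# Indices need not be contiguous or sorted, and may repeat
reordered = tf.gather(tensor, [2, 0, 2])

# axis=1 gathers columns instead of rows
columns = tf.gather(tensor, [0, 2], axis=1)

# Select rows by a condition: keep rows whose first element is greater than 1
mask = tensor[:, 0] > 1
filtered = tf.boolean_mask(tensor, mask)

print(reordered.numpy())
print(columns.numpy())
print(filtered.numpy())
```

Here reordered contains rows 2, 0 and 2 in that order, columns contains the first and last column of every row, and filtered contains the last two rows.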