LINUX.ORG.RU

TensorFlow и жалкие попытки его использования

 , , ,


1

2

У меня есть два вопроса:

1) Что творится с этим долбанным фреймворком? Где нормльные книги, где нормальные статить, где нормальные доки? В официальных доках черт ногу сломит (это отличительная черта гугла, например, в доках по апи ютуба и гугл плюса тоже самое), при попытки нагуглить какую-то банальную вещь сразу натыкаешься на version-hell. Я на столько отчаялся в поисках нормальной обучающей инфы, что даже купил вот эту книгу, но меня ждал сюрприз - книга предполагает, что я знаю TensorFlow, лол.

2) У меня есть простая, на первый взгляд, задачка. Мне нужно сохранить обученную модель, загрузить ее и, собственно, использовать.

Вот так я создаю и тренирую модель:

def build_model(learning_rate=0.1):
    tf.reset_default_graph()
    
    net = tflearn.input_data([None, VOCAB_SIZE])
    net = tflearn.fully_connected(net, 125, activation='ReLU')
    net = tflearn.fully_connected(net, 25, activation='ReLU')
    net = tflearn.fully_connected(net, 2, activation='softmax')
    regression = tflearn.regression(
        net, 
        optimizer='sgd', 
        learning_rate=learning_rate, 
        loss='categorical_crossentropy')
    
    model = tflearn.DNN(net)
    return model

Создаем модель:

model = build_model(learning_rate=0.75)

Тренируем:

model.fit(
    X_train, 
    y_train, 
    validation_set=0.1, 
    show_metric=True, 
    batch_size=128, 
    n_epoch=30)

Далее, я пытаюсь сохранить модель:

saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, './sentiment-model')

Получаю вот такие варнинги:

WARNING:tensorflow:Error encountered when serializing data_preprocessing.
Type is unsupported, or the types of the items don't match field type in CollectionDef.
'NoneType' object has no attribute 'name'
WARNING:tensorflow:Error encountered when serializing data_augmentation.
Type is unsupported, or the types of the items don't match field type in CollectionDef.
'NoneType' object has no attribute 'name'
WARNING:tensorflow:Error encountered when serializing summary_tags.
Type is unsupported, or the types of the items don't match field type in CollectionDef.
'dict' object has no attribute 'name'

Но файлики таки создаются:

checkpoint
sentiment-model.data-00000-of-00001		
sentiment-model.index			
sentiment-model.meta

И вот что делать дальше, я не понимаю абсолютно. По идее, нужно выполнить код, типа этого:

sess = tf.Session()
new_saver = tf.train.import_meta_graph('sentiment-model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
all_vars = tf.get_collection('vars')
for v in all_vars:
    v_ = sess.run(v)
    print(v_)

Но тут я получаю ошибку:

"The name 'SGD' refers to an Operation not in the graph."
Видимо, она вылазит из за optimizer='sgd', но почему?

Deleted

Ответ на: комментарий от anonymous

Рыжий толстый с широкой костью. Похож на кота из мультика про попугая, который «таити-таити, нас и здесь неплохо кормят».

DELIRIUM ☆☆☆☆☆
()

TFLearn и tensorflow — это разные штуки.

Если пользуешься TFLearn, то я бы советовал и сохранять через него model.save/model.load.

Если писать на чистом tensorflow, то нужно сначала создать сессию, и тренировать уже в сессии. И в этой ситуации у меня сохранение работало. Видимо, мешать код tflearn и tensorflow лучше только, если хорошо понимаешь, как tflearn устроен внутри.

Davidov ★★★★
()
Ответ на: комментарий от Davidov

Ага, спасибо за наводку, щас попробую.

Deleted
()
Ответ на: комментарий от Davidov

Что касается tensorflow, то примеры тут. Но про несовместимость версий и плохую документацию — чистая правда. Например, они поменяли способ создания MultiRNNCell, и теперь половина существующего кода работает, но не так как ожидается.

Для высокого уровня, очень рекомендую keras. Доки неплохие, есть slack, где можно задать вопрос.

Davidov ★★★★
()
Ответ на: комментарий от Davidov

TFLearn и tensorflow — это разные штуки.

Это самая полезная инфа, которую я узнал за неделю.

В общем, с сохранением и загрузкой вопрос решен. Сейчас, на сколько я понимаю, мне нужно разобраться как поднять gRPC сервер, чтобы юзать сеть в своих проектах.

Deleted
()
Последнее исправление: Bizun (всего исправлений: 1)

Родина им дала PyTorch с человеческим API - пользуйся! Не хочу, хочу жрать гуглокактус.

anonymous
()
Вы не можете добавлять комментарии в эту тему. Тема перемещена в архив.