"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
"
\n",
" [500/500 01:17, Epoch 1/2]\n",
"
\n",
" \n",
" \n",
" \n",
" Step | \n",
" Training Loss | \n",
" Validation Loss | \n",
" Accuracy | \n",
"
\n",
" \n",
" \n",
" \n",
" 50 | \n",
" 0.465300 | \n",
" 0.060249 | \n",
" 0.981166 | \n",
"
\n",
" \n",
" 100 | \n",
" 0.053500 | \n",
" 0.037531 | \n",
" 0.989238 | \n",
"
\n",
" \n",
" 150 | \n",
" 0.070400 | \n",
" 0.031749 | \n",
" 0.991928 | \n",
"
\n",
" \n",
" 200 | \n",
" 0.095200 | \n",
" 0.026542 | \n",
" 0.991928 | \n",
"
\n",
" \n",
" 250 | \n",
" 0.029900 | \n",
" 0.023062 | \n",
" 0.993722 | \n",
"
\n",
" \n",
" 300 | \n",
" 0.032900 | \n",
" 0.023522 | \n",
" 0.992825 | \n",
"
\n",
" \n",
" 350 | \n",
" 0.026700 | \n",
" 0.021454 | \n",
" 0.993722 | \n",
"
\n",
" \n",
" 400 | \n",
" 0.018700 | \n",
" 0.020306 | \n",
" 0.994619 | \n",
"
\n",
" \n",
" 450 | \n",
" 0.023600 | \n",
" 0.020214 | \n",
" 0.995516 | \n",
"
\n",
" \n",
" 500 | \n",
" 0.007300 | \n",
" 0.020931 | \n",
" 0.994619 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Checkpoint destination directory spam_classifier/checkpoint-500 already exists and is non-empty. Saving will proceed but saved results may be invalid.\n"
]
},
{
"data": {
"text/plain": [
"TrainOutput(global_step=500, training_loss=0.08234859538078308, metrics={'train_runtime': 91.3384, 'train_samples_per_second': 87.586, 'train_steps_per_second': 5.474, 'total_flos': 129662297999304.0, 'train_loss': 0.08234859538078308, 'epoch': 1.79})"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer.train()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Как видим точность получилась очень хорошей: 0.995!\n",
"\n",
"По напечатанной выше ссылке можно посмотреть графики изменения лосса и точности в процессе обучения.\n",
"\n",
"Также посчитаем предсказания берта и сохраним их, далее они нам пригодятся."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9946188340807175\n"
]
}
],
"source": [
"# Применяем модель к тестовому датасету, получаем логиты\n",
"bert_logits = trainer.predict(tokenized_test).predictions\n",
"\n",
"# Для каждого элемента берем индекс максимального логита, это и есть наш класс\n",
"bert_predictions = np.argmax(bert_logits, axis=1)\n",
"\n",
"# Проверяем точность\n",
"np.mean(data_test.label == bert_predictions)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 3.4 Сравнение с наивным байесовским классификатором\n",
"\n",
"Теперь разберемся, а точно ли качество выросло? Чтобы убедиться, что это не просто шум, надо провести статистический тест. Для начала вычисляем вектора ошибок."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "lMZOcjQD4IUE",
"outputId": "6feccc13-73e8-4ace-acd4-f2c2504901a9"
},
"outputs": [
{
"data": {
"text/plain": [
"(0.9739910313901345, 0.9946188340807175)"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"is_error_multinomial_nb = (data_test.label == multinomial_predictions).astype(\n",
" np.int32\n",
")\n",
"is_error_bert = (data_test.label == bert_predictions).astype(np.int32)\n",
"\n",
"# Убедимся, что всё правильно посчитали и точность не изменилась\n",
"np.mean(is_error_multinomial_nb), np.mean(is_error_bert)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Теперь t-тест для связанных выборок, он проверяет гипотезу о равенстве средних выборок."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "T-Wim4ok46Mv",
"outputId": "e967949d-9f6d-40b9-a018-22fe083cd60d"
},
"outputs": [
{
"data": {
"text/plain": [
"TtestResult(statistic=-4.463758730826827, pvalue=8.871553503645221e-06, df=1114)"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ttest_rel(is_error_multinomial_nb, is_error_bert)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Видно, что `pvalue` $= 8.8 \\cdot 10^{-6} < 0.05$, а значит мы имеем статистическое доказательство того, что берт справляется с классификацией спама значимо лучше наивного байесовского классификатора.\n",
"\n",
"*Подробнее с теорией и практикой проверки статистических гипотез вы познакомитесь на 3 курсе*."
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [
"AKg26mO_0j-D"
],
"gpuType": "T4",
"provenance": []
},
"hide_input": false,
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
}
},
"nbformat": 4,
"nbformat_minor": 1
}