author     Slendi <slendi@socopon.com>  2023-11-05 00:34:30 +0200
committer  Slendi <slendi@socopon.com>  2023-11-05 00:34:30 +0200
commit     b8acb45e6837dbdd22b96a48110f4137236618d5 (patch)
tree       0d88ac924f2bb0bf3ba2aa157c51954a0856df5b
Initial commit.
Signed-off-by: Slendi <slendi@socopon.com>
-rw-r--r--   .gitignore           5
-rwxr-xr-x   create_dataset.sh   19
-rw-r--r--   dataset_maker.py    48
-rwxr-xr-x   interactive.py      15
-rwxr-xr-x   split_file.py       43
-rwxr-xr-x   train.py           145
6 files changed, 275 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..ed07e02
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,5 @@
+venv
+train.json
+dataset
+final_model
+
diff --git a/create_dataset.sh b/create_dataset.sh
new file mode 100755
index 0000000..6922c45
--- /dev/null
+++ b/create_dataset.sh
@@ -0,0 +1,19 @@
+#!/bin/sh
+
+set -xe
+
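+# dataset_maker.py downloads Topical Chat and writes messages_good.txt and
+# messages_to_be_uwuified.txt; uwuify then turns the latter into messages_uwuified.txt.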
+python3 dataset_maker.py
+uwuify -t $(nproc) messages_to_be_uwuified.txt messages_uwuified.txt
+
+rm -f messages.txt messages_to_be_uwuified.txt
+
+rm -rf dataset
+mkdir -p dataset/normal dataset/uwu
+#mv messages_good.txt dataset/normal/normal_text_1.txt
+#mv messages_uwuified.txt dataset/uwu/uwu_text_1.txt
+
+set +x
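+# Write one message per file so text_dataset_from_directory treats each message as a sample.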
+python3 split_file.py messages_good.txt dataset/normal $(nproc)
+python3 split_file.py messages_uwuified.txt dataset/uwu $(nproc)
+
+rm messages_good.txt messages_uwuified.txt
diff --git a/dataset_maker.py b/dataset_maker.py
new file mode 100644
index 0000000..fe70193
--- /dev/null
+++ b/dataset_maker.py
@@ -0,0 +1,48 @@
+import json
+import os
+import requests
+
+input_file_path = "train.json"
+
+output_file_path = "messages.txt"
+
+if not os.path.exists(input_file_path):
+    print('Downloading Topical Chat dataset')
+    url = "https://raw.githubusercontent.com/alexa/Topical-Chat/master/conversations/train.json"
+    response = requests.get(url)
+    with open(input_file_path, 'wb') as file:
+        file.write(response.content)
+
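+# train.json maps each conversation id to an object whose "content" list holds the messages.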
+with open(input_file_path, 'r') as json_file, open(output_file_path, 'w') as output_file:
+    data = json.load(json_file)
+
+    for key, value in data.items():
+        if "content" in value:
+            for message_item in value["content"]:
+                if "message" in message_item:
+                    message = message_item["message"]
+                    output_file.write(message + '\n')
+
+print("Messages extracted and saved to", output_file_path)
+
+# Split 50/50
+input_file_path = "messages.txt"
+
+output_file_path_1 = "messages_to_be_uwuified.txt"
+output_file_path_2 = "messages_good.txt"
+
+with open(input_file_path, 'r') as input_file:
+    messages = input_file.readlines()
+
+split_point = len(messages) // 2
+
+messages_split_1 = messages[:split_point]
+messages_split_2 = messages[split_point:]
+
+with open(output_file_path_1, 'w') as output_file_1:
+    output_file_1.writelines(messages_split_1)
+
+with open(output_file_path_2, 'w') as output_file_2:
+    output_file_2.writelines(messages_split_2)
+
+print("Messages split into two files:", output_file_path_1, "and", output_file_path_2)
diff --git a/interactive.py b/interactive.py
new file mode 100755
index 0000000..1006cdc
--- /dev/null
+++ b/interactive.py
@@ -0,0 +1,15 @@
+#!/usr/bin/env python3
+
+import tensorflow as tf
+
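+# The exported model's TextVectorization layer uses this custom standardizer,
+# so it has to be registered before the SavedModel can be loaded.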
+@tf.keras.utils.register_keras_serializable(package='Custom', name=None)
+def text_standardizer(input_data):
+    lowercase = tf.strings.lower(input_data)
+    return lowercase
+
+with tf.keras.utils.CustomObjectScope({'text_standardizer': text_standardizer}):
+    model = tf.keras.models.load_model('final_model')
+    model.summary()
+
+    while True:
+        print(model.predict([input('> ')]))
diff --git a/split_file.py b/split_file.py
new file mode 100755
index 0000000..0b96d65
--- /dev/null
+++ b/split_file.py
@@ -0,0 +1,43 @@
+#!/usr/bin/env python3
+
+import os
+import sys
+import multiprocessing
+
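+# Write each line in [start, end) of input_file to its own text_<n>.txt file in output_dir.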
+def split_lines(input_file, output_dir, start, end):
+    with open(input_file, 'r') as input_text_file:
+        lines = input_text_file.readlines()[start:end]
+
+    for index, line in enumerate(lines):
+        output_file = os.path.join(output_dir, f'text_{start + index}.txt')
+        with open(output_file, 'w') as output_text_file:
+            output_text_file.write(line)
+
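+# Split the input file across nprocs worker processes; each writes its slice as one file per line.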
+def main():
+    if len(sys.argv) != 4:
+        print("Usage: python split_file.py input_file.txt output_directory nprocs")
+    else:
+        input_file = sys.argv[1]
+        output_dir = sys.argv[2]
+        nprocs = int(sys.argv[3])
+        if not os.path.exists(output_dir):
+            os.makedirs(output_dir)
+
+        with open(input_file, 'r') as input_text_file:
+            lines = input_text_file.readlines()
+
+        chunk_size = len(lines) // nprocs
+        processes = []
+
+        for i in range(nprocs):
+            start = i * chunk_size
+            end = start + chunk_size if i < nprocs - 1 else len(lines)
+            process = multiprocessing.Process(target=split_lines, args=(input_file, output_dir, start, end))
+            process.start()
+            processes.append(process)
+
+        for process in processes:
+            process.join()
+
+if __name__ == "__main__":
+    main()
diff --git a/train.py b/train.py
new file mode 100755
index 0000000..e894da8
--- /dev/null
+++ b/train.py
@@ -0,0 +1,145 @@
+#!/usr/bin/env python3
+
+import matplotlib.pyplot as plt
+import os
+import re
+import shutil
+import string
+import tensorflow as tf
+
+from tensorflow.keras.saving import register_keras_serializable
+from tensorflow.keras import layers
+from tensorflow.keras import losses
+
+print(tf.__version__)
+
+BATCH_SIZE = 32
+SEED = 69420
+
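+# dataset/normal and dataset/uwu hold one message per file; labels are taken from the directory names.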
+print('Loading training dataset')
+raw_train_ds = tf.keras.utils.text_dataset_from_directory(
+    'dataset',
+    batch_size=BATCH_SIZE,
+    subset='training',
+    seed=SEED,
+    label_mode="int",
+    class_names=['normal', 'uwu'],
+    validation_split=0.2,
+)
+
+print('Loading validation dataset')
+raw_val_ds = tf.keras.utils.text_dataset_from_directory(
+    'dataset',
+    batch_size=BATCH_SIZE,
+    subset='validation',
+    seed=SEED,
+    label_mode="int",
+    class_names=['normal', 'uwu'],
+    validation_split=0.2,
+)
+
+print('Loading testing dataset')
+raw_test_ds = tf.keras.utils.text_dataset_from_directory(
+    'dataset',
+    batch_size=BATCH_SIZE,
+    label_mode="int",
+    class_names=['normal', 'uwu'],
+)
+
+@tf.keras.utils.register_keras_serializable(package='Custom', name=None)
+def text_standardizer(input_data):
+    lowercase = tf.strings.lower(input_data)
+    return lowercase
+
+MAX_FEATURES = 10000
+SEQUENCE_LENGTH = 240
+
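+# Map each raw string to SEQUENCE_LENGTH integer token ids over a vocabulary of at most MAX_FEATURES tokens.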
+vectorize_layer = layers.TextVectorization(
+    standardize=text_standardizer,
+    max_tokens=MAX_FEATURES,
+    output_mode='int',
+    output_sequence_length=SEQUENCE_LENGTH
+)
+
+train_text = raw_train_ds.map(lambda x, y: x)
+vectorize_layer.adapt(train_text)
+
+def vectorize_text(text, label):
+    text = tf.expand_dims(text, -1)
+    return vectorize_layer(text), label
+
+#text_batch, label_batch = next(iter(raw_train_ds))
+#first_message, first_label = text_batch[0], label_batch[0]
+#print('Message', first_message)
+#print('Label', first_label)
+#print('Vectorized message', vectorize_text(first_message, first_label))
+
+train_ds = raw_train_ds.map(vectorize_text)
+val_ds = raw_val_ds.map(vectorize_text)
+test_ds = raw_test_ds.map(vectorize_text)
+
+AUTOTUNE = tf.data.AUTOTUNE
+
+train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
+val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
+test_ds = test_ds.cache().prefetch(buffer_size=AUTOTUNE)
+
+print('Creating model')
+
+EMBEDDING_DIM = 16
+
+model = tf.keras.Sequential([
+    layers.Embedding(MAX_FEATURES, EMBEDDING_DIM),
+    layers.Dropout(0.2),
+    layers.GlobalAveragePooling1D(),
+    layers.Dropout(0.2),
+    layers.Dense(1)])
+
+model.summary()
+
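+# Dense(1) outputs raw logits, hence from_logits=True for the loss and threshold=0.0 for accuracy.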
+model.compile(loss=losses.BinaryCrossentropy(from_logits=True),
+              optimizer='adam',
+              metrics=tf.metrics.BinaryAccuracy(threshold=0.0))
+
+epochs = 10
+history = model.fit(
+    train_ds,
+    validation_data=val_ds,
+    epochs=epochs)
+
+loss, accuracy = model.evaluate(test_ds)
+
+print("Loss: ", loss)
+print("Accuracy: ", accuracy)
+
+history_dict = history.history
+history_dict.keys()
+
+acc = history_dict['binary_accuracy']
+val_acc = history_dict['val_binary_accuracy']
+loss = history_dict['loss']
+val_loss = history_dict['val_loss']
+
+epochs = range(1, len(acc) + 1)
+
+print('Exporting model')
+
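+# Bundle vectorization, the trained model and a sigmoid so the exported model scores raw strings directly.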
+export_model = tf.keras.Sequential([
+    vectorize_layer,
+    model,
+    layers.Activation('sigmoid')
+])
+
+export_model.compile(
+    loss=losses.BinaryCrossentropy(from_logits=False), optimizer="adam", metrics=['accuracy']
+)
+
+# Test it with `raw_test_ds`, which yields raw strings
+loss, accuracy = export_model.evaluate(raw_test_ds)
+print(accuracy)
+
+print('Saving model')
+export_model.save('final_model', save_format='tf')
+
+while True:
+    export_model.predict([input('> ')])