ডিপ লার্নিং দিয়ে ডিসিশন বাউন্ডারি, চাঁদ আকারের ডাটাসেট
(পুরো চ্যাপ্টার একটা নোটবুক)
Last updated
(পুরো চ্যাপ্টার একটা নোটবুক)
Last updated
আমরা বই পড়ছি, নোটবুক কেন পড়বো?
যেহেতু গিটবুকে নোটবুক ঠিকমতো রেন্ডার হয়না, সেকারণে গুগল কোলাব এবং গিটহাবে দেখা উচিৎ। গিটহাব লিংক: https://github.com/raqueeb/TensorFlow2/blob/master/deep_learning_decision_boundary.ipynb এবং গুগল কোলাব লিংক: https://colab.research.google.com/github/raqueeb/TensorFlow2/blob/master/deep_learning_decision_boundary.ipynb
যেকোনো প্রজেক্ট এর কাজের শুরুতে আমার একটা অভ্যাস হচ্ছে সেই ডাটাসেটকে প্লট করে দেখা। একটা ডাটাসেটের ঠিকমতো প্লটিং দেখতে পারলে তার সমস্যার অনেকটাইকুল কিনারা করা যায়। এজন্যই ডাটা ভিজুয়ালাইজেশন এতটাই দরকারি। আমরা যখন সাইকিট-লার্ন নিয়ে কাজ করছিলাম তখন এধরনের ডাটার অনেক প্লটিং দেখেছিলাম।
চিত্রঃ বিভিন্ন ধরনের ডেটার প্লটিং
%matplotlib inline
import matplotlib.pyplot as plt
ডাক্তার প্লটিং এর পাশাপাশি প্রয়োজনে দরকারি ডাটাসেট না পেলে অনেক ধরনের ডাটাসেট স্যাম্পল জেনারেট করতে হয় যা আসলে আমাদের দরকারি সাইজ এবং কম্প্লেক্সিটি বুঝেই করতে হয়। এই ধরনের জেনারেটর আমাদের ফিচার এবং তার করেসপন্ডিং ‘ডিসক্রিট’ টার্গেটের একটা ম্যাট্রিক্স তৈরি করে দেয়। পাশাপাশি কিছু নয়েজও ঢোকাতে হয় রিয়ালিস্টিক করার জন্য। লিনিয়ার ডাটাসেটের পাশাপাশি কিছু নন-লিনিয়ার ডাটা নিয়ে কাজ করতে গেলে আলাদা জেনারেটর ব্যবহার করি আমরা। যেমন, দুইটা বৃত্তাংশ একটা আরেকটার মধ্যে ঢুকে গেছে। দুটা অর্ধেক চাঁদ। সাইকিট-লার্নে এটার একটা ফাংশন আছে, দেখুন এই উদাহরণে।
আমাদের কাজ হবে এই ধরনের একটা নন-লিনিয়ার ডাটাসেটে কিভাবে ডিপ লার্নিং দিয়ে সুন্দর ক্লাসিফিকেশন করা যায়। এর আগেও করেছি, তবে এই ডেটা জেনারেটর দিয়ে তৈরি, এবং আরো কমপ্লেক্স। একে sklearn.datasets থেকে ইমপোর্ট করছি এই নতুন ফাংশন দিয়ে। ফিচার এবং টার্গেট ভ্যারিয়েবল হিসেবে দুটো আলাদা আলাদা স্যাম্পল জেনারেট করছি। এরপর প্লটিং।
from sklearn.datasets import make_moons
X, y = make_moons(n_samples=1000, noise=0.1, random_state=0)
plt.plot(X[y==0, 0], X[y==0, 1], 'ob', alpha=0.5)
plt.plot(X[y==1, 0], X[y==1, 1], 'xr', alpha=0.5)
plt.legend(['0', '1'])
স্যাম্পল এর সংখ্যা দেখি।
X.shape
(1000, 2)
টেস্ট এবং ট্রেনিং ডাটা স্প্লিট করি।
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y,
test_size=0.3,
random_state=42)
try:
# %tensorflow_version only exists in Colab.
# শুধুমাত্র জুপিটার নোটবুক/কোলাবে চেষ্টা করবো টেন্সর-ফ্লো ২.০ এর জন্য
%tensorflow_version 2.x
except Exception:
pass
TensorFlow 2.x selected.
import tensorflow as tf
এরপর আমরা এর ইনপুটকে এক লাইনের/লেয়ারের নিউরাল নেটওয়ার্কের দিয়ে দেই। কি ঘটছে?
model = tf.keras.models.Sequential([
tf.keras.layers.Dense(1, input_dim=2, activation='sigmoid')
])
model.compile(tf.keras.optimizers.Adam(lr=0.05), 'binary_crossentropy', metrics=['accuracy'])
model.fit(X_train, y_train, epochs=200, verbose=0)
WARNING:tensorflow:From /tensorflow-2.0.0-rc2/python3.6/tensorflow_core/python/ops/nn_impl.py:183: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
<tensorflow.python.keras.callbacks.History at 0x7f21ab9359b0>
মডেলকে ইভালুয়েট করে দেখি অ্যাক্যুরেসি কত এসেছে। লস এবং অ্যাক্যুরেসি দেখতে পাচ্ছি এখানে।
results = model.evaluate(X_test, y_test)
300/1 [========================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================] - 0s 268us/sample - loss: 0.2587 - accuracy: 0.8367
results
[0.31679999371369677, 0.83666664]
ডাটা প্লটিং এর পাশাপাশি আমরা ডিসিশন বাউন্ডারি দেখার চেষ্টা করি। এখানে ব্যবহার করেছি mlxtend.plotting এর মতো হেল্পার লাইব্রেরি। এর plot_decision_regions কাজ হচ্ছে মডেল থেকে যা ইনপুট দেয় সেটা সে প্লট করে দেয়। যেমন এখানে ডিসিশন বাউন্ডারি করে দিয়েছে ঠিকই তবে সেটা কাজ করছেনা। সত্যিই তো, এখানকার ডিসিশন বাউন্ডারি লিনিয়ার, কিন্তু ডাটা তো নন-লিনিয়ার।
from mlxtend.plotting import plot_decision_regions
plot_decision_regions(X=X, y=y, clf=model, legend=2)
plt.xlabel("x", size=5)
plt.ylabel("y", size=5)
plt.title('Plot Decision Region Boundary', size=10)
plt.show()
আমরা একটু তিন লাইনের একটা ডিপ লার্নিং মডেল থেকে আউটকাম দেখার চেষ্টা করি। ৩ লেয়ার, পাশাপাশি নিউরনের সংখ্যাও বেশি। কি ঘটছে এখন?
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(4, input_shape=(2,), activation='tanh'))
model.add(tf.keras.layers.Dense(2, activation='tanh'))
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
model.compile(tf.keras.optimizers.Adam(lr=0.05), 'binary_crossentropy', metrics=['accuracy'])
model.fit(X_train, y_train, epochs=100, verbose=0)
model.evaluate(X_test, y_test)
300/1 [========================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================] - 0s 325us/sample - loss: 5.6714e-04 - accuracy: 1.0000
[0.0009398714792526637, 1.0]
from sklearn.metrics import accuracy_score
দেখতে পাচ্ছি অ্যাকুরেসি বেড়েছে।
y_train_pred = model.predict_classes(X_train)
y_test_pred = model.predict_classes(X_test)
print("The Accuracy score on the Train set is:\t{:0.3f}".format(accuracy_score(y_train, y_train_pred)))
print("The Accuracy score on the Test set is:\t{:0.3f}".format(accuracy_score(y_test, y_test_pred)))
The Accuracy score on the Train set is: 0.999
The Accuracy score on the Test set is: 1.000
নতুন করে ডিসিশন বাউন্ডারি প্লটিং করে দেখি। অসাধারণ কাজ হয়েছে।
plot_decision_regions(X=X, y=y, clf=model, legend=2)
plt.xlabel("x", size=5)
plt.ylabel("y", size=5)
plt.title('Plot Decision Region Boundary', size=10)
plt.show()