Fit vs. Fit_Transform in Scikit-learn for Machine Learning
We have seen methods such as fit(), transform(), and fit_transform() in many scikit-learn classes. What do these methods mean? Is it fitting a model on the training data and transforming the test data? Can we use fit_transform() on X_test as well? And if we do, would the imputed values differ from those produced by transform()? In this post, we’ll try to understand the difference between them.
fit() is used to "train" the model: it calculates the parameters needed for the transformation. Only after fitting can you call transform() on the data. If you want, you can do both steps at once with fit_transform().
Easy, right? BUT… let’s see what fit(), transform() and fit_transform() actually do in the scikit-learn toolbox.
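Take feature scaling as an example (this is what scikit-learn’s StandardScaler does). Here these methods center and scale the given data. For this, we use the Z-score method, which rescales each feature to a mean of 0 and a standard deviation of 1. Note that it does not restrict the values to a fixed range.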
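Written out for a single feature with n training values x_1, …, x_n, the Z-score transformation is shown below. Here σ is the population standard deviation (dividing by n), which is what StandardScaler uses; fit() computes μ and σ, and transform() computes z for each value x.

```latex
\mu = \frac{1}{n}\sum_{i=1}^{n} x_i, \qquad
\sigma = \sqrt{\frac{1}{n}\sum_{i=1}^{n} \left(x_i - \mu\right)^2}, \qquad
z = \frac{x - \mu}{\sigma}
```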
- fit(): calculates the parameters μ (mean) and σ (standard deviation) from the data and saves them as internal attributes of the object.
- transform(): applies these saved parameters to the actual data and returns the standardized values.
- fit_transform(): combines fit() and transform() into a single call on the same dataset. So it all depends on you: call fit(), do any other calculations, and then call transform(); or do it all at once with fit_transform().
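To see the three methods in action, here is a minimal sketch using scikit-learn’s StandardScaler, which implements the Z-score scaling described above. The toy array is made-up data for illustration; mean_ and scale_ are the attributes where StandardScaler stores μ and σ after fit().

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Hypothetical toy training data: 4 samples, 2 features
X_train = np.array([[1.0, 10.0],
                    [2.0, 20.0],
                    [3.0, 30.0],
                    [4.0, 40.0]])

# Option 1: fit() learns mu and sigma, then transform() applies them
scaler = StandardScaler()
scaler.fit(X_train)
print("mu:", scaler.mean_)      # per-feature means learned by fit()
print("sigma:", scaler.scale_)  # per-feature standard deviations learned by fit()
X_two_steps = scaler.transform(X_train)

# Option 2: fit_transform() does both steps in one call
X_one_step = StandardScaler().fit_transform(X_train)

# Both routes give the same standardized values
print(np.allclose(X_two_steps, X_one_step))  # True

# Same result by hand with the Z-score formula (population std, ddof=0)
manual = (X_train - X_train.mean(axis=0)) / X_train.std(axis=0)
print(np.allclose(X_one_step, manual))  # True
```

After scaling, each column of X_one_step has mean 0 and standard deviation 1.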
Code snippets:
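import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
X = np.random.rand(100, 3)  # hypothetical toy features; replace with your own data
y = np.random.randint(0, 2, size=100)  # hypothetical toy labels
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
model = StandardScaler()  # assumed transformer; others follow the same fit/transform API
X_train_vectorized = model.fit_transform(X_train)  # learn parameters on train, then transform
X_test_vectorized = model.transform(X_test)  # reuse the training parameters on test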
But why do we only call transform() on X_test, and not fit_transform()?
fit_transform() means doing some calculation and then a transformation (say, calculating the column means from some data and then replacing the missing values with them). So for the training set, you need to both calculate and transform.
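The mean-imputation case can be sketched with scikit-learn’s SimpleImputer(strategy="mean"), which learns the column means in fit() and stores them in its statistics_ attribute. The arrays below are hypothetical toy data. The sketch also answers the question from the start of the post: refitting on the test set does produce different imputed values.

```python
import numpy as np
from sklearn.impute import SimpleImputer

# Hypothetical toy data; np.nan marks missing values
X_train = np.array([[1.0, 2.0],
                    [3.0, np.nan],
                    [5.0, 6.0]])
X_test = np.array([[np.nan, 100.0],
                   [np.nan, np.nan],
                   [9.0, 200.0]])

imputer = SimpleImputer(strategy="mean")
# "Calculate" step: learn the column means from the training set only
X_train_filled = imputer.fit_transform(X_train)
print(imputer.statistics_)  # [3. 4.]

# Correct: fill the test gaps with the TRAINING means (3.0 and 4.0)
X_test_filled = imputer.transform(X_test)
print(X_test_filled)

# Incorrect: refitting recomputes the means from the test data itself
# (9.0 for column 0 and 150.0 for column 1), so the imputed values differ
X_test_refit = SimpleImputer(strategy="mean").fit_transform(X_test)
print(X_test_refit)
```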
For the test set, however, the model makes predictions based on what it learned from the training set, so nothing needs to be recalculated; we just apply the same transformation.
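If you called fit_transform(X_test), it would recalculate μ and σ on the test (validation) set. The test data would then be transformed with different parameters than the ones the model was trained with. And because the test set is usually much smaller than the training set, its μ and σ are also a poorer estimate of the data set’s true mean and standard deviation.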
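Here is a small sketch of this effect, assuming a hypothetical normally distributed feature and a three-row test set: transform() reuses the training μ and σ, while fit_transform(X_test) recomputes them from the test rows alone.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
# Hypothetical data: one feature drawn from a normal distribution (mean 50, std 10)
X_train = rng.normal(loc=50.0, scale=10.0, size=(1000, 1))
X_test = np.array([[45.0], [55.0], [60.0]])  # a small test set

scaler = StandardScaler().fit(X_train)
print(scaler.mean_, scaler.scale_)  # close to 50 and 10

# Correct: scale the test data with mu and sigma learned from the training set
X_test_ok = scaler.transform(X_test)
# Incorrect: refitting recomputes mu and sigma from just these 3 test rows
X_test_refit = StandardScaler().fit_transform(X_test)

print(X_test_ok.ravel())
print(X_test_refit.ravel())  # always mean 0 and std 1 over the test rows

# Extreme case: a single test sample
single = np.array([[80.0]])
print(scaler.transform(single).ravel())                # large positive value: far above the training mean
print(StandardScaler().fit_transform(single).ravel())  # [0.]
```

In the single-sample case, refitting sets μ to that sample’s own value, and since its variance is zero StandardScaler uses a scale of 1, so the sample always comes out as 0 no matter how unusual it is.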
Hope this clears up your doubt.
To put it simply, use the fit_transform() method on the training set, since you need to both fit and transform the data. Afterwards, call the same object’s transform() method to apply that transformation to the validation (test) set.
Let me know if you have any comments or are not able to understand it.