|
|
# Early-stopping demo for polynomial regression trained with SGD.
#
# 1. Generate a noisy quadratic dataset (y = 2 + x + 0.5 x^2 + noise).
# 2. Train an over-parameterised model (degree-90 polynomial features,
#    standardised) one epoch at a time on part of the data, and keep a
#    snapshot of the model at the epoch with the lowest validation MSE.
# 3. Fit a degree-2 SGD model on the full dataset for that many epochs and
#    plot its predictions on top of the raw data.
import copy

from sklearn.base import clone
from sklearn.pipeline import Pipeline
from sklearn import preprocessing
from sklearn.preprocessing import PolynomialFeatures
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import SGDRegressor
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

np.random.seed(42)
m = 100
X = 6 * np.random.rand(m, 1) - 3
y = 2 + X + 0.5 * X**2 + np.random.randn(m, 1)

plt.plot(X, y, ".", label="Datos originales")

# Only the first 50 samples take part in the epoch search; they are split
# 50/50 (test_size=0.5) into training and validation sets.
X_train, X_val, y_train, y_val = train_test_split(
    X[:50], y[:50].ravel(), test_size=0.5, random_state=10
)

# Degree-90 features followed by standardisation: a deliberately
# over-parameterised feature set for the early-stopping search.
poly_scaler = Pipeline([
    ("poly_features", PolynomialFeatures(degree=90, include_bias=False)),
    ("std_scaler", StandardScaler()),
])

X_train_poly_scaled = poly_scaler.fit_transform(X_train)
X_val_poly_scaled = poly_scaler.transform(X_val)

# max_iter=1 together with warm_start=True makes every fit() call run one more
# epoch starting from the current weights. tol=-np.inf keeps sklearn's own
# stopping criterion from ever triggering (np.infty no longer exists in
# NumPy 2.0, np.inf is the portable spelling).
sgd_reg = SGDRegressor(max_iter=1, tol=-np.inf, warm_start=True, penalty=None,
                       learning_rate="constant", eta0=0.0005, random_state=42)
print(sgd_reg)
minimum_val_error = float("inf")
best_epoch = None
best_model = None
for epoch in range(1000):
    sgd_reg.fit(X_train_poly_scaled, y_train)  # continues where it left off
    y_val_predict = sgd_reg.predict(X_val_poly_scaled)
    val_error = mean_squared_error(y_val, y_val_predict)
    if val_error < minimum_val_error:
        minimum_val_error = val_error
        best_epoch = epoch
        # deepcopy keeps the fitted coef_/intercept_. sklearn.base.clone would
        # return a fresh, *unfitted* estimator with the same hyper-parameters,
        # discarding exactly the weights we want to keep.
        best_model = copy.deepcopy(sgd_reg)

print(best_epoch)

# best_epoch is a 0-based loop index: the best snapshot had been trained for
# best_epoch + 1 epochs. Passing best_epoch itself trains one epoch short and
# gives the invalid max_iter=0 when the very first epoch was the best one.
n_best_epochs = best_epoch + 1

# NOTE(review): n_best_epochs was selected for the degree-90, standardised
# model above, while the model below uses unscaled degree-2 features and all
# 100 samples, so the epoch count only loosely transfers. Confirm intended.
sgd_reg = SGDRegressor(max_iter=n_best_epochs, tol=-np.inf, penalty=None,
                       learning_rate="constant", eta0=0.0005, random_state=42)

poly_features = PolynomialFeatures(degree=2, include_bias=False)
X_pol = poly_features.fit_transform(X)

sgd_reg.fit(X_pol, y.ravel())
yout = sgd_reg.predict(X_pol)

plt.plot(X, yout, "*", label="Predicciones")

# Axis labels and title of the figure.
plt.xlabel("Eje X")
plt.ylabel("Eje Y")
plt.title("Regresión polinómica con SGD")

plt.legend()
plt.show()
|