Browse Source

Metodo de detención temprana

master
parent
commit
70686ae433
1 changed files with 72 additions and 0 deletions
  1. +72
    -0
      Earlystop.py

+ 72
- 0
Earlystop.py View File

@ -0,0 +1,72 @@
from sklearn.base import clone
from sklearn.pipeline import Pipeline
from sklearn import preprocessing
from sklearn.preprocessing import PolynomialFeatures
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import SGDRegressor
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
np.random.seed(42)
m = 100
X = 6 * np.random.rand(m, 1) - 3
y = 2 + X + 0.5 * X**2 + np.random.randn(m, 1)
plt.plot(X,y,".", label = "Datos originales")
X_train, X_val, y_train, y_val = train_test_split(X[:50], y[:50].ravel(), test_size=0.5, random_state=10)
poly_scaler = Pipeline([
("poly_features", PolynomialFeatures(degree=90, include_bias=False)),
("std_scaler", StandardScaler()),
])
X_train_poly_scaled = poly_scaler.fit_transform(X_train)
X_val_poly_scaled = poly_scaler.transform(X_val)
sgd_reg = SGDRegressor(max_iter=1, tol=-np.infty, warm_start=True, penalty=None,
learning_rate="constant", eta0=0.0005, random_state=42)
print(sgd_reg)
minimum_val_error = float("inf")
best_epoch = None
best_model = None
for epoch in range(1000):
sgd_reg.fit(X_train_poly_scaled, y_train) # continues where it left off
y_val_predict = sgd_reg.predict(X_val_poly_scaled)
val_error = mean_squared_error(y_val, y_val_predict)
if val_error < minimum_val_error:
minimum_val_error = val_error
best_epoch = epoch
best_model = clone(sgd_reg)
print(best_epoch)
sgd_reg = SGDRegressor(max_iter=best_epoch, tol=-np.infty, warm_start=True, penalty=None,
learning_rate="constant", eta0=0.0005, random_state=42)
poly_features = PolynomialFeatures(degree=2, include_bias=False)
X_pol = poly_features.fit_transform(X)
sgd_reg.fit(X_pol,y.ravel())
yout=sgd_reg.predict(X_pol)
plt.plot(X,yout,"*", label = "Predicciones")
# naming the x axis
plt.xlabel('Eje X')
# naming the y axis
plt.ylabel('Eje Y')
# giving a title to my graph
plt.legend()
plt.show()

Loading…
Cancel
Save