研究问题:如何根据电影上映前的一些信息来预测出该电影的票房。
数据来源
数据来自 TMDB(The Movie Database)电影数据库中 7000 多部已上映电影的元数据,包括演员、工作人员、情节关键词、预算、海报、上映日期、语言、制作公司和制片国家等信息。本文实际使用的 TMDB.csv 约有 3000 条记录(见下文缺失值统计)。
数据导入
# Load the TMDB movie metadata
import pandas as pd
df = pd.read_csv("https://labfile.oss.aliyuncs.com/courses/1363/TMDB.csv")
df.head()
# Inspect the data: shape, per-column dtype / non-null counts, summary stats
df.shape
df.info()
df.describe()
df.describe(include=['O'])  # summary of the object (string) columns
# The ten highest-grossing movies
df.sort_values('revenue', ascending=False)[['title', 'revenue', 'release_date']].head(10)
数据预处理
上映时间
release_date 为电影的上映时间列。处理时将其中的年、月、日以及季度信息分别提取出来。
def date_features(df):
    """Add calendar features derived from the ``release_date`` column.

    ``release_date`` is converted to datetime in place, and the year, month,
    day and quarter are written to new ``release_year``, ``release_month``,
    ``release_day`` and ``release_quarter`` columns.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame containing a ``release_date`` column; it is modified in place.

    Returns
    -------
    pandas.DataFrame
        The same (mutated) frame, so the call can be chained or reassigned.
    """
    df['release_date'] = pd.to_datetime(df['release_date'])  # parse to Timestamps
    dates = df['release_date'].dt
    df['release_year'] = dates.year
    df['release_month'] = dates.month
    df['release_day'] = dates.day
    df['release_quarter'] = dates.quarter
    return df
# Derive the calendar columns on the main frame, then preview the year
df = date_features(df)
df.release_year.head()
检查是否存在异常值。由于数据收集于 2019 年之前,上映年份不应超过 2019 年;超过的值多半是两位数年份被解析到了错误的世纪,因此下面将其减去 100 年。
# The data was collected before 2019, so any release_year > 2019 is an
# anomaly (presumably a two-digit year parsed into the wrong century).
import numpy as np
# Preview up to 10 anomalous years.
df.loc[df['release_year'] > 2019, 'release_year'].head(10)
# Anomalies exist: shift every year after 2019 back by one century.
df['release_year'] = np.where(
    df['release_year'] > 2019, df['release_year'] - 100, df['release_year'])
# Re-check: this selection should now be empty.
df.loc[df['release_year'] > 2019, 'release_year'].head(10)
# Confirm the extracted year/month/day columns have no missing values
# (other columns in the frame do still contain nulls).
cols = ['release_year', 'release_month', 'release_day']
df[cols].isna().sum()
#显示每个月的平均电影票房
from matplotlib import pyplot as plt
%matplotlib inline
fig = plt.figure(figsize=(14, 4))
df.groupby('release_month').agg('mean')['revenue'].plot(kind='bar', rot=0)
plt.ylabel('Revenue (100 million dollars)')
#由图可以看到,电影的上映时间主要集中在 6 月和 12 月。这可能的原因是这两段时间都是假期。
# Line chart of the mean box-office revenue per release year
release_year_mean_data = df.groupby('release_year')['revenue'].mean()
fig = plt.figure(figsize=(14, 5))
plt.plot(release_year_mean_data)
plt.title('Mean revenue Over Years')
plt.ylabel('Mean revenue value')
# The yearly mean revenue trends upward over time, plausibly linked to
# economic growth and a larger share of spending going to entertainment.
# Tip: state a conclusion after every chart.
下面用与上面类似的代码,绘制每年电影平均时长(runtime)的变化。
# Line chart of the mean runtime per release year.
release_year_mean_data = df.groupby(['release_year'])['runtime'].mean()
fig = plt.figure(figsize=(14, 5))  # canvas size
plt.plot(release_year_mean_data)
plt.ylabel('Mean runtime value')  # the plotted quantity is runtime, not popularity
plt.title('Mean runtime Over Years')
# Before about 1980 the mean runtime fluctuates; after 1980 it stabilises at
# a little over 100 minutes.
电影系列(belongs_to_collection)
# Print the first 5 belongs_to_collection values and their types;
# enumerate yields (index, value) pairs.
for i, e in enumerate(df['belongs_to_collection'][:5]):
    print(i, e)
    print(type(e))
# A value is present iff it is a string (missing entries are read as NaN).
df['belongs_to_collection'].apply(
    lambda x: 1 if isinstance(x, str) else 0).value_counts()
# Of the 3000 rows, 2396 have no collection.
# Extract the collection name, and add a flag recording whether one exists.
# NOTE(review): eval() executes whatever text the downloaded CSV contains;
# ast.literal_eval is the safe parser for these literal strings.
df['collection_name'] = df['belongs_to_collection'].apply(
    lambda x: eval(x)[0]['name'] if isinstance(x, str) else 0)
df['has_collection'] = df['belongs_to_collection'].apply(
    lambda x: 1 if isinstance(x, str) else 0)
df[['collection_name', 'has_collection']].head()
# Preview the raw genres strings (list-of-dict literals with a 'name' key).
for i, e in enumerate(df['genres'][:5]):
    print(i, e)
# Parse each row into the list of its genre names ([] when missing).
list_of_genres = list(df['genres'].apply(
    lambda x: [i['name'] for i in eval(x)] if isinstance(x, str) else []).values)
list_of_genres[:5]
# Count how many movies carry each genre, most frequent first.
from collections import Counter
most_common_genres = Counter(
    genre for genres in list_of_genres for genre in genres).most_common()
most_common_genres
# Horizontal bar chart with the most frequent genre at the top.
fig = plt.figure(figsize=(10, 6))
data = dict(most_common_genres)
names = list(data.keys())
values = list(data.values())
plt.barh(sorted(range(len(data)), reverse=True),
         values, tick_label=names, color='teal')
plt.xlabel('Count')
plt.title('Movie Genre Count')
plt.show()
词云图
#先安装 词云库 wordcloud
!pip install wordcloud
from wordcloud import WordCloud
plt.figure(figsize=(12, 8))
text = ' '.join([i for j in list_of_genres for i in j])
# 设置参数
wordcloud = WordCloud(max_font_size=None, background_color='white', collocations=False,
width=1200, height=1000).generate(text)
plt.imshow(wordcloud)
plt.title('Top genres')
plt.axis("off")
plt.show()
# Genre features: number of genres, a sorted space-joined genre string, and
# one-hot flags for the 15 most common genres.
df['num_genres'] = df['genres'].apply(
    lambda x: len(eval(x)) if isinstance(x, str) else 0)
df['all_genres'] = df['genres'].apply(lambda x: ' '.join(
    sorted([i['name'] for i in eval(x)])) if isinstance(x, str) else '')
top_genres = [m[0] for m in Counter(
    [i for j in list_of_genres for i in j]).most_common(15)]
# Match whole genre names per row: a substring test on the joined string would
# also fire when one name is contained in another. list_of_genres is in df row
# order, so assigning a list fills the column positionally.
genre_sets = [set(names) for names in list_of_genres]
for g in top_genres:
    df['genre_' + g] = [1 if g in s else 0 for s in genre_sets]
cols = [i for i in df.columns if 'genre_' in str(i)]
df[cols].head()
import plotly.graph_objs as go
import plotly.offline as py
py.init_notebook_mode(connected=False)
# Subsets of movies for the four most common genres.
drama = df.loc[df['genre_Drama'] == 1, ]
comedy = df.loc[df['genre_Comedy'] == 1, ]
action = df.loc[df['genre_Action'] == 1, ]
thriller = df.loc[df['genre_Thriller'] == 1, ]
# Mean revenue per release year for each genre. Select 'revenue' before
# .mean(): averaging whole frames raises on pandas >= 2.0 (non-numeric columns).
drama_revenue = drama.groupby(['release_year'])['revenue'].mean()
comedy_revenue = comedy.groupby(['release_year'])['revenue'].mean()
action_revenue = action.groupby(['release_year'])['revenue'].mean()
thriller_revenue = thriller.groupby(['release_year'])['revenue'].mean()
# concat(axis=1) aligns the four series on release_year (NaN where a genre has
# no movie that year), so the index already holds the years; sort it so the
# lines are drawn chronologically.
revenue_concat = pd.concat([drama_revenue,
                            comedy_revenue,
                            action_revenue,
                            thriller_revenue],
                           axis=1).sort_index()
revenue_concat.columns = ['drama', 'comedy', 'action', 'thriller']
data = [go.Scatter(x=revenue_concat.index, y=revenue_concat[name], name=name)
        for name in revenue_concat.columns]
# Draw the interactive chart.
layout = go.Layout(dict(title='Mean Revenue by Top 4 Movie Genres Over Years',
                        xaxis=dict(title='Year'),
                        yaxis=dict(title='Revenue'),
                        ), legend=dict(
    orientation="v"))
py.iplot(dict(data=data, layout=layout))
制片公司
# Preview the raw production_companies strings.
for i, e in enumerate(df['production_companies'][:5]):
    print(i, e)
# Parse each row into the list of its company names ([] when missing).
list_of_companies = list(df['production_companies'].apply(
    lambda x: [i['name'] for i in eval(x)] if isinstance(x, str) else []).values)
# Number of movies per production company (top 20).
most_common_companies = Counter(
    [i for j in list_of_companies for i in j]).most_common(20)
fig = plt.figure(figsize=(10, 6))
data = dict(most_common_companies)
names = list(data.keys())
values = list(data.values())
plt.barh(sorted(range(len(data)), reverse=True),
         values, tick_label=names, color='brown')
plt.xlabel('Count')
plt.title('Top 20 Production Company Count')
plt.show()
# Company features: number of companies, a sorted space-joined company string,
# and one-hot flags for the 30 most common companies.
# Count the parsed company entries, not the characters of the raw string.
df['num_companies'] = df['production_companies'].apply(
    lambda x: len(eval(x)) if isinstance(x, str) else 0)
df['all_production_companies'] = df['production_companies'].apply(
    lambda x: ' '.join(sorted([i['name'] for i in eval(x)])) if isinstance(x, str) else '')
top_companies = [m[0] for m in Counter(
    [i for j in list_of_companies for i in j]).most_common(30)]
# Exact name match per row: a substring test on the joined string would also
# flag any company whose name merely contains g (e.g. "X" inside "X Corporation").
company_sets = [set(names) for names in list_of_companies]
for g in top_companies:
    df['production_company_' + g] = [1 if g in s else 0 for s in company_sets]
cols = [i for i in df.columns if 'production_company' in str(i)]
df[cols].head()
# Subsets of movies for four of the most common production companies.
Warner_Bros = df.loc[df['production_company_Warner Bros.'] == 1, ]
Universal_Pictures = df.loc[df['production_company_Universal Pictures'] == 1, ]
Twentieth_Century_Fox_Film = df.loc[df['production_company_Twentieth Century Fox Film Corporation'] == 1, ]
Columbia_Pictures = df.loc[df['production_company_Columbia Pictures'] == 1, ]
# Mean revenue per release year for each company. Select 'revenue' before
# .mean(): averaging whole frames raises on pandas >= 2.0.
Warner_Bros_revenue = Warner_Bros.groupby(['release_year'])['revenue'].mean()
Universal_Pictures_revenue = Universal_Pictures.groupby(
    ['release_year'])['revenue'].mean()
Twentieth_Century_Fox_Film_revenue = Twentieth_Century_Fox_Film.groupby(
    ['release_year'])['revenue'].mean()
Columbia_Pictures_revenue = Columbia_Pictures.groupby(
    ['release_year'])['revenue'].mean()
# Align the four series on release_year; sort so later line plots run
# chronologically.
prod_revenue_concat = pd.concat([Warner_Bros_revenue,
                                 Universal_Pictures_revenue,
                                 Twentieth_Century_Fox_Film_revenue,
                                 Columbia_Pictures_revenue], axis=1).sort_index()
prod_revenue_concat.columns = ['Warner_Bros',
                               'Universal_Pictures',
                               'Twentieth_Century_Fox_Film',
                               'Columbia_Pictures']
fig = plt.figure(figsize=(13, 5))
# Average each company's yearly means and draw them as horizontal bars.
# x=/y= only apply to DataFrame plots, so the axis labels are set explicitly.
prod_revenue_concat.mean(axis=0).sort_values(ascending=True).plot(
    kind='barh',
    title='Mean Revenue (100 million dollars) of Most Common Production Companies')
plt.xlabel('Revenue (100 million dollars)')
plt.ylabel('Production Companies')
# Interactive line chart: yearly mean revenue of each of the four companies.
data = [go.Scatter(x=prod_revenue_concat.index,
                   y=prod_revenue_concat[company],
                   name=company)
        for company in ['Warner_Bros',
                        'Universal_Pictures',
                        'Twentieth_Century_Fox_Film',
                        'Columbia_Pictures']]
layout = go.Layout(
    dict(title='Mean Revenue of Movie Production Companies over Years',
         xaxis=dict(title='Year'),
         yaxis=dict(title='Revenue')),
    legend=dict(orientation="v"),
)
py.iplot(dict(data=data, layout=layout))
制片国家(production_countries)
# Preview the raw production_countries strings.
for i, e in enumerate(df['production_countries'][:5]):
    print(i, e)
# Parse each row into the list of its country names ([] when missing).
list_of_countries = list(df['production_countries'].apply(
    lambda x: [i['name'] for i in eval(x)] if isinstance(x, str) else []).values)
# Number of movies per production country (top 20).
most_common_countries = Counter(
    [i for j in list_of_countries for i in j]).most_common(20)
fig = plt.figure(figsize=(10, 6))
data = dict(most_common_countries)
names = list(data.keys())
values = list(data.values())
plt.barh(sorted(range(len(data)), reverse=True),
         values, tick_label=names, color='purple')
plt.xlabel('Count')
plt.title('Country Count')
plt.show()
# Country features: number of production countries, a sorted space-joined
# country string, and one-hot flags for the 25 most common countries.
df['num_countries'] = df['production_countries'].apply(
    lambda x: len(eval(x)) if isinstance(x, str) else 0)
df['all_countries'] = df['production_countries'].apply(lambda x: ' '.join(
    sorted([i['name'] for i in eval(x)])) if isinstance(x, str) else '')
top_countries = [m[0] for m in Counter(
    [i for j in list_of_countries for i in j]).most_common(25)]
# Exact name match per row (a substring test would let a short country name
# match inside a longer one). list_of_countries is in df row order.
country_sets = [set(names) for names in list_of_countries]
for g in top_countries:
    df['production_country_' + g] = [1 if g in s else 0 for s in country_sets]
cols = [i for i in df.columns if 'production_country' in str(i)]
df[cols].head()
电影语言
# Preview the raw spoken_languages strings.
for i, e in enumerate(df['spoken_languages'][:5]):
    print(i, e)
# Parse each row into the list of its language names ([] when missing).
list_of_languages = list(df['spoken_languages'].apply(
    lambda x: [i['name'] for i in eval(x)] if isinstance(x, str) else []).values)
# Number of movies per spoken language (top 20).
most_common_languages = Counter(
    [i for j in list_of_languages for i in j]).most_common(20)
fig = plt.figure(figsize=(10, 6))
data = dict(most_common_languages)
names = list(data.keys())
values = list(data.values())
plt.barh(sorted(range(len(data)), reverse=True), values, tick_label=names)
plt.xlabel('Count')
plt.title('Language Count')
plt.show()
# Language features: number of spoken languages, a sorted space-joined string
# of ISO 639-1 codes, and one-hot flags for the 30 most common languages.
df['num_languages'] = df['spoken_languages'].apply(
    lambda x: len(eval(x)) if isinstance(x, str) else 0)
df['all_languages'] = df['spoken_languages'].apply(lambda x: ' '.join(
    sorted([i['iso_639_1'] for i in eval(x)])) if isinstance(x, str) else '')
top_languages = [m[0] for m in Counter(
    [i for j in list_of_languages for i in j]).most_common(30)]
# top_languages holds language *names* while all_languages holds ISO codes,
# so each flag is computed against the row's parsed names (exact match).
language_sets = [set(names) for names in list_of_languages]
for g in top_languages:
    df['language_' + g] = [1 if g in s else 0 for s in language_sets]
cols = [i for i in df.columns if 'language_' in str(i)]
df[cols].head()
# Preview the raw Keywords strings.
for i, e in enumerate(df['Keywords'][:5]):
    print(i, e)
# Count keyword occurrences across all movies.
list_of_keywords = list(df['Keywords'].apply(
    lambda x: [i['name'] for i in eval(x)] if isinstance(x, str) else []).values)
most_common_keywords = Counter(
    [i for j in list_of_keywords for i in j]).most_common(20)
fig = plt.figure(figsize=(10, 6))
data = dict(most_common_keywords)
names = list(data.keys())
values = list(data.values())
plt.barh(sorted(range(len(data)), reverse=True),
         values, tick_label=names, color='purple')
plt.xlabel('Count')
plt.title('Top 20 Most Common Keyword Count')
plt.show()
按电影类型(剧情、喜剧、动作、惊悚)分别绘制关键词词云图
# Keyword word clouds for the four most common genres.
def _keyword_text(frame):
    """Return every keyword name of the movies in *frame* as one string.

    Each movie's keyword names are sorted and space-joined; movies without
    keywords contribute an empty string.
    """
    per_movie = frame['Keywords'].apply(
        lambda x: ' '.join(sorted([i['name'] for i in eval(x)])) if isinstance(x, str) else '')
    return " ".join(per_movie)


text_drama = _keyword_text(drama)
text_comedy = _keyword_text(comedy)
text_action = _keyword_text(action)
text_thriller = _keyword_text(thriller)
wordcloud1 = WordCloud(background_color="white",
                       colormap="Reds").generate(text_drama)
wordcloud2 = WordCloud(background_color="white",
                       colormap="Blues").generate(text_comedy)
wordcloud3 = WordCloud(background_color="white",
                       colormap="Greens").generate(text_action)
wordcloud4 = WordCloud(background_color="white",
                       colormap="Greys").generate(text_thriller)
# Draw all four clouds on a single 2x2 figure so no quadrant is left empty.
fig = plt.figure(figsize=(25, 20))
for position, cloud, title in [(221, wordcloud1, 'Drama Keywords'),
                               (222, wordcloud2, 'Comedy Keywords'),
                               (223, wordcloud3, 'Action Keywords'),
                               (224, wordcloud4, 'Thriller Keywords')]:
    plt.subplot(position)
    plt.imshow(cloud, interpolation='bilinear')
    plt.title(title)
    plt.axis("off")
plt.show()
对关键词(Keywords)列进行特征提取
# Keyword features: number of keywords, a sorted space-joined keyword string,
# and one-hot flags for the 30 most common keywords.
df['num_Keywords'] = df['Keywords'].apply(
    lambda x: len(eval(x)) if isinstance(x, str) else 0)
df['all_Keywords'] = df['Keywords'].apply(lambda x: ' '.join(
    sorted([i['name'] for i in eval(x)])) if isinstance(x, str) else '')
top_keywords = [m[0] for m in Counter(
    [i for j in list_of_keywords for i in j]).most_common(30)]
# Exact keyword match per row: a substring test on the joined string would let
# a short keyword match inside longer ones. list_of_keywords is in df row order.
keyword_sets = [set(names) for names in list_of_keywords]
for g in top_keywords:
    df['keyword_' + g] = [1 if g in s else 0 for s in keyword_sets]
cols = [i for i in df.columns if 'keyword_' in str(i)]
df[cols].head()
演员
# Preview the first raw cast string.
for i, e in enumerate(df['cast'][:1]):
    print(i, e)
# Which actors appear in the most movies?
list_of_cast_names = list(df['cast'].apply(
    lambda x: [i['name'] for i in eval(x)] if isinstance(x, str) else []).values)
most_common_cast_names = Counter(
    [i for j in list_of_cast_names for i in j]).most_common(20)
fig = plt.figure(figsize=(10, 6))
data = dict(most_common_cast_names)
names = list(data.keys())
values = list(data.values())
plt.barh(sorted(range(len(data)), reverse=True),
         values, tick_label=names, color='purple')
plt.xlabel('Count')
plt.title('Top 20 Most Common Cast Count')
plt.show()
提取演员相关特征
# Cast features: cast size, a sorted space-joined name string, and one-hot
# flags for the 30 most frequently cast actors.
df['num_cast'] = df['cast'].apply(
    lambda x: len(eval(x)) if isinstance(x, str) else 0)
df['all_cast'] = df['cast'].apply(lambda x: ' '.join(
    sorted([i['name'] for i in eval(x)])) if isinstance(x, str) else '')
top_cast_names = [m[0] for m in Counter(
    [i for j in list_of_cast_names for i in j]).most_common(30)]
# Exact name match per row (a substring test would let one name match inside
# another). list_of_cast_names is in df row order.
cast_sets = [set(names) for names in list_of_cast_names]
for g in top_cast_names:
    df['cast_name_' + g] = [1 if g in s else 0 for s in cast_sets]
cols = [i for i in df.columns if 'cast_name' in str(i)]
df[cols].head()
画出参演电影数量最多的 4 位演员所参演电影的平均票房。
# Mean revenue of the movies featuring each of four frequently cast actors.
cast_name_Samuel_L_Jackson = df.loc[df['cast_name_Samuel L. Jackson'] == 1, ]
cast_name_Robert_De_Niro = df.loc[df['cast_name_Robert De Niro'] == 1, ]
cast_name_Morgan_Freeman = df.loc[df['cast_name_Morgan Freeman'] == 1, ]
cast_name_J_K_Simmons = df.loc[df['cast_name_J.K. Simmons'] == 1, ]
# Average only the 'revenue' column: DataFrame.mean() over all columns raises
# on pandas >= 2.0 because the frame has non-numeric columns.
cast_name_Samuel_L_Jackson_revenue = cast_name_Samuel_L_Jackson['revenue'].mean()
cast_name_Robert_De_Niro_revenue = cast_name_Robert_De_Niro['revenue'].mean()
cast_name_Morgan_Freeman_revenue = cast_name_Morgan_Freeman['revenue'].mean()
cast_name_J_K_Simmons_revenue = cast_name_J_K_Simmons['revenue'].mean()
cast_revenue_concat = pd.Series([cast_name_Samuel_L_Jackson_revenue,
                                 cast_name_Robert_De_Niro_revenue,
                                 cast_name_Morgan_Freeman_revenue,
                                 cast_name_J_K_Simmons_revenue])
cast_revenue_concat.index = ['Samuel L. Jackson',
                             'Robert De Niro',
                             'Morgan Freeman',
                             'J.K. Simmons', ]
fig = plt.figure(figsize=(13, 5))
cast_revenue_concat.sort_values(ascending=True).plot(
    kind='barh', title='Mean Revenue (100 million dollars) by Top 4 Most Common Cast')
plt.xlabel('Revenue (100 million dollars)')
对演员性别等特征进行提取
# Per-movie cast composition: gender-code counts and flags for common
# character names.
list_of_cast_genders = list(df['cast'].apply(
    lambda x: [i['gender'] for i in eval(x)] if isinstance(x, str) else []).values)
list_of_cast_characters = list(df['cast'].apply(
    lambda x: [i['character'] for i in eval(x)] if isinstance(x, str) else []).values)
# Each element of list_of_cast_genders is one movie's list of gender codes, so
# count within that list. Comparing the whole list to an int is never true,
# which would make every count 0. The lists are in df row order.
df['genders_0'] = [genders.count(0) for genders in list_of_cast_genders]
df['genders_1'] = [genders.count(1) for genders in list_of_cast_genders]
df['genders_2'] = [genders.count(2) for genders in list_of_cast_genders]
top_cast_characters = [m[0] for m in Counter(
    [i for j in list_of_cast_characters for i in j]).most_common(15)]
# Flag movies whose cast includes one of the 15 most common character names.
# Match exactly against the parsed names: a substring test on the raw cast
# string can hit other fields, and an empty name would match every movie.
character_sets = [set(chars) for chars in list_of_cast_characters]
for g in top_cast_characters:
    df['cast_character_' + g] = [1 if g in s else 0 for s in character_sets]
cols = [i for i in df.columns if 'cast_cha' in str(i)]
df[cols].head()
制作团队
# Preview the first raw crew string.
for i, e in enumerate(df['crew'][:1]):
    print(i, e)
# Count how many movies each crew member worked on.
list_of_crew_names = list(df['crew'].apply(
    lambda x: [i['name'] for i in eval(x)] if isinstance(x, str) else []).values)
most_common_crew_names = Counter(
    [i for j in list_of_crew_names for i in j]).most_common(20)
fig = plt.figure(figsize=(10, 6))
data = dict(most_common_crew_names)
names = list(data.keys())
values = list(data.values())
plt.barh(sorted(range(len(data)), reverse=True),
         values, tick_label=names, color='purple')
plt.xlabel('Count')
plt.title('Top 20 Most Common Crew Count')
plt.show()
# Crew features: crew size, a sorted space-joined name string, and one-hot
# flags for the 30 most frequent crew members.
df['num_crew'] = df['crew'].apply(
    lambda x: len(eval(x)) if isinstance(x, str) else 0)
df['all_crew'] = df['crew'].apply(lambda x: ' '.join(
    sorted([i['name'] for i in eval(x)])) if isinstance(x, str) else '')
top_crew_names = [m[0] for m in Counter(
    [i for j in list_of_crew_names for i in j]).most_common(30)]
# Exact name match per row (a substring test would let one name match inside
# another). list_of_crew_names is in df row order.
crew_sets = [set(names) for names in list_of_crew_names]
for g in top_crew_names:
    df['crew_name_' + g] = [1 if g in s else 0 for s in crew_sets]
cols = [i for i in df.columns if 'crew_name' in str(i)]
df[cols].head()
对参与电影数量排名前 4 位的制作团队成员进行分析。crew 列记录的是整个制作团队,因此这些人不一定都是导演。
# Mean revenue of the movies each of four frequent crew members worked on.
crew_name_Avy_Kaufman = df.loc[df['crew_name_Avy Kaufman'] == 1, ]
crew_name_Robert_Rodriguez = df.loc[df['crew_name_Robert Rodriguez'] == 1, ]
crew_name_Deborah_Aquila = df.loc[df['crew_name_Deborah Aquila'] == 1, ]
crew_name_James_Newton_Howard = df.loc[df['crew_name_James Newton Howard'] == 1, ]
# Average only the 'revenue' column: DataFrame.mean() over all columns raises
# on pandas >= 2.0 because the frame has non-numeric columns.
crew_name_Avy_Kaufman_revenue = crew_name_Avy_Kaufman['revenue'].mean()
crew_name_Robert_Rodriguez_revenue = crew_name_Robert_Rodriguez['revenue'].mean()
crew_name_Deborah_Aquila_revenue = crew_name_Deborah_Aquila['revenue'].mean()
crew_name_James_Newton_Howard_revenue = crew_name_James_Newton_Howard['revenue'].mean()
crew_revenue_concat = pd.Series([crew_name_Avy_Kaufman_revenue,
                                 crew_name_Robert_Rodriguez_revenue,
                                 crew_name_Deborah_Aquila_revenue,
                                 crew_name_James_Newton_Howard_revenue])
crew_revenue_concat.index = ['Avy Kaufman',
                             'Robert Rodriguez',
                             'Deborah Aquila',
                             'James Newton Howard']
fig = plt.figure(figsize=(13, 5))
# Four people are plotted, so the title says "Top 4".
crew_revenue_concat.sort_values(ascending=True).plot(
    kind='barh', title='Mean Revenue (100 million dollars) by Top 4 Most Common Crew')
plt.xlabel('Revenue (100 million dollars)')
特征工程
# Compare the raw revenue distribution with its log(1 + x) transform.
fig = plt.figure(figsize=(15, 10))
# Left panel: raw revenue
plt.subplot(221)
df.revenue.plot(kind='hist', bins=100)
plt.title('Distribution of Revenue')
plt.xlabel('Revenue')
# Right panel: log1p(revenue), the transform used later as the model target
plt.subplot(222)
df.revenue.pipe(np.log1p).plot(kind='hist', bins=100)
plt.title('Train Log Revenue Distribution')
plt.xlabel('Log Revenue')
对预算(budget)列做同样的操作,比较原始分布与 log1p 对数变换后的分布。
# Same comparison for the budget column: raw values vs log(1 + x).
fig = plt.figure(figsize=(15, 10))
# Left panel: raw budget
plt.subplot(221)
df.budget.plot(kind='hist', bins=100)
plt.title('Train Budget Distribution')
plt.xlabel('Budget')
# Right panel: log1p(budget)
plt.subplot(222)
df.budget.pipe(np.log1p).plot(kind='hist', bins=100)
plt.title('Train Log Budget Distribution')
plt.xlabel('Log Budget')
plt.show()
# Build the model matrix. Drop identifiers, free text, the raw list-of-dict
# columns and their joined-string helpers, then drop any column that still
# contains missing values.
drop_columns = ['homepage', 'imdb_id', 'poster_path', 'status',
                'title', 'release_date', 'tagline', 'overview',
                'original_title', 'all_genres', 'all_cast',
                'original_language', 'collection_name', 'all_crew',
                'belongs_to_collection', 'genres', 'production_companies',
                'all_production_companies', 'production_countries',
                'all_countries', 'spoken_languages', 'all_languages',
                'Keywords', 'all_Keywords', 'cast', 'crew']
df_drop = df.drop(drop_columns, axis=1).dropna(axis=1, how='any')
df_drop.head()
from sklearn.model_selection import train_test_split
# Features: every remaining column except the id and the target.
# Target: log1p(revenue).
data_X = df_drop.drop(['id', 'revenue'], axis=1)
data_y = np.log1p(df_drop['revenue'])
# A fixed random_state makes the 80/20 split, and therefore the reported
# MSE values, reproducible across runs.
train_X, test_X, train_y, test_y = train_test_split(
    data_X, data_y.values, test_size=0.2, random_state=42)
预测模型
from sklearn.linear_model import Lasso, Ridge
from sklearn.metrics import mean_squared_error
# Lasso regression (default hyper-parameters) on the log1p(revenue) target.
model = Lasso()
model.fit(train_X, train_y)          # train the model
y_pred = model.predict(test_X)       # predict on the held-out split
mean_squared_error(test_y, y_pred)   # evaluate: MSE, arguments are (y_true, y_pred)
# Lasso's test mean squared error is roughly 6-7 (in log1p(revenue) units).
# Ridge regression (default hyper-parameters) on the same split.
model = Ridge()
model.fit(train_X, train_y)
y_pred = model.predict(test_X)
mean_squared_error(test_y, y_pred)
# Ridge reaches an MSE of about 6.5, slightly better than Lasso.
总结:数据预处理(手动提取特征并可视化)——特征工程(比较票房列和预算列经 log1p 对数变换前后的分布,并以 log1p(票房) 作为预测目标)——构建预测模型(Lasso 回归和 Ridge 回归)。
原文地址:https://blog.csdn.net/Sun123234/article/details/125418140
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.7code.cn/show_31174.html
如若内容造成侵权/违法违规/事实不符,请联系代码007邮箱:suwngjj01@126.com进行投诉反馈,一经查实,立即删除!
声明:本站所有文章,如无特殊说明或标注,均为本站原创发布。任何个人或组织,在未征得本站同意时,禁止复制、盗用、采集、发布本站内容到任何网站、书籍等各类媒体平台。如若本站内容侵犯了原著者的合法权益,可联系我们进行处理。